diff --git a/.dotnet/scripts/Add-Customizations.ps1 b/.dotnet/scripts/Add-Customizations.ps1 index 4b4d4cd92..d0351d130 100644 --- a/.dotnet/scripts/Add-Customizations.ps1 +++ b/.dotnet/scripts/Add-Customizations.ps1 @@ -38,11 +38,15 @@ function Set-LangVersionToLatest { $xml.Save($filePath) } -function Edit-RunObjectSerialization { +function Edit-DateTimeOffsetSerialization { + param( + [string]$filename + ) + $root = Split-Path $PSScriptRoot -Parent $directory = Join-Path -Path $root -ChildPath "src\Generated\Models" - $file = Get-ChildItem -Path $directory -Filter "RunObject.Serialization.cs" + $file = Get-ChildItem -Path $directory -Filter $filename $content = Get-Content -Path $file -Raw Write-Output "Editing $($file.FullName)" @@ -52,6 +56,7 @@ function Edit-RunObjectSerialization { $content = $content -creplace "cancelledAt = property\.Value\.GetDateTimeOffset\(`"O`"\);", "// BUG: https://github.com/Azure/autorest.csharp/issues/4296`r`n // cancelledAt = property.Value.GetDateTimeOffset(`"O`");`r`n cancelledAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());" $content = $content -creplace "failedAt = property\.Value\.GetDateTimeOffset\(`"O`"\);", "// BUG: https://github.com/Azure/autorest.csharp/issues/4296`r`n // failedAt = property.Value.GetDateTimeOffset(`"O`");`r`n failedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());" $content = $content -creplace "completedAt = property\.Value\.GetDateTimeOffset\(`"O`"\);", "// BUG: https://github.com/Azure/autorest.csharp/issues/4296`r`n // completedAt = property.Value.GetDateTimeOffset(`"O`");`r`n completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());" + $content = $content -creplace "incompleteAt = property\.Value\.GetDateTimeOffset\(`"O`"\);", "// BUG: https://github.com/Azure/autorest.csharp/issues/4296`r`n // completedAt = property.Value.GetDateTimeOffset(`"O`");`r`n completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64());" $content | Set-Content -Path $file.FullName -NoNewline } @@ -59,4 +64,5 @@ function Edit-RunObjectSerialization { Update-SystemTextJsonPackage Update-MicrosoftBclAsyncInterfacesPackage Set-LangVersionToLatest -Edit-RunObjectSerialization \ No newline at end of file +Edit-DateTimeOffsetSerialization -filename "RunObject.Serialization.cs" +Edit-DateTimeOffsetSerialization -filename "MessageObject.Serialization.cs" diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs b/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs index 826d90bf8..0fb6ac5a0 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.Protocol.cs @@ -1,6 +1,8 @@ +using System; using System.ClientModel; using System.ClientModel.Primitives; using System.ComponentModel; +using System.Text; using System.Threading.Tasks; namespace OpenAI.Assistants; @@ -336,6 +338,15 @@ public virtual ClientResult CreateRun( RequestOptions options = null) => RunShim.CreateRun(threadId, content, options); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult CreateRunStreaming(string threadId, BinaryContent content, RequestOptions options = null) + { + PipelineMessage message = CreateCreateRunRequest(threadId, content, stream: true, options); + RunShim.Pipeline.Send(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual async Task CreateRunAsync( @@ -344,6 +355,15 @@ public virtual async Task CreateRunAsync( RequestOptions options = null) => await RunShim.CreateRunAsync(threadId, content, options).ConfigureAwait(false); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task CreateRunStreamingAsync(string threadId, BinaryContent content, RequestOptions options = null) + { + PipelineMessage message = CreateCreateRunRequest(threadId, content, stream: true, options); + await RunShim.Pipeline.SendAsync(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual ClientResult CreateThreadAndRun( @@ -351,6 +371,15 @@ public virtual ClientResult CreateThreadAndRun( RequestOptions options = null) => RunShim.CreateThreadAndRun(content, options); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult CreateThreadAndRunStreaming(BinaryContent content, RequestOptions options = null) + { + PipelineMessage message = CreateCreateThreadAndRunRequest(content, stream: true, options); + RunShim.Pipeline.Send(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual async Task CreateThreadAndRunAsync( @@ -358,6 +387,15 @@ public virtual async Task CreateThreadAndRunAsync( RequestOptions options = null) => await RunShim.CreateThreadAndRunAsync(content, options).ConfigureAwait(false); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task CreateThreadAndRunStreamingAsync(BinaryContent content, RequestOptions options = null) + { + PipelineMessage message = CreateCreateThreadAndRunRequest(content, stream: true, options); + await RunShim.Pipeline.SendAsync(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual ClientResult GetRun( @@ -439,6 +477,19 @@ public virtual ClientResult SubmitToolOutputs( RequestOptions options = null) => RunShim.SubmitToolOuputsToRun(threadId, runId, content, options); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult SubmitToolOutputsStreaming( + string threadId, + string runId, + BinaryContent content, + RequestOptions options = null) + { + PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, content, stream: true, options); + RunShim.Pipeline.Send(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// [EditorBrowsable(EditorBrowsableState.Never)] public virtual async Task SubmitToolOutputsAsync( @@ -448,6 +499,19 @@ public virtual async Task SubmitToolOutputsAsync( RequestOptions options = null) => await RunShim.SubmitToolOuputsToRunAsync(threadId, runId, content, options).ConfigureAwait(false); + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task SubmitToolOutputsStreamingAsync( + string threadId, + string runId, + BinaryContent content, + RequestOptions options = null) + { + PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, content, stream: true, options); + await RunShim.Pipeline.SendAsync(message); + return ClientResult.FromResponse(message.ExtractResponse()); + } + /// public virtual ClientResult GetRunStep( string threadId, @@ -485,4 +549,33 @@ public virtual async Task GetRunStepsAsync( string subsequentStepId, RequestOptions options) => await RunShim.GetRunStepsAsync(threadId, runId, maxResults, createdSortOrder, previousStepId, subsequentStepId, options).ConfigureAwait(false); + + internal PipelineMessage CreateCreateRunRequest(string threadId, BinaryContent content, bool? stream = null, RequestOptions options = null) + => CreateAssistantProtocolRequest($"/threads/{threadId}/runs", content, stream, options); + + internal PipelineMessage CreateCreateThreadAndRunRequest(BinaryContent content, bool? stream = null, RequestOptions options = null) + => CreateAssistantProtocolRequest($"/threads/runs", content, stream, options); + + internal PipelineMessage CreateSubmitToolOutputsRequest(string threadId, string runId, BinaryContent content, bool? stream = null, RequestOptions options = null) + => CreateAssistantProtocolRequest($"/threads/{threadId}/runs/{runId}/submit_tool_outputs", content, stream, options); + + internal PipelineMessage CreateAssistantProtocolRequest(string path, BinaryContent content, bool? stream = null, RequestOptions options = null) + { + PipelineMessage message = Shim.Pipeline.CreateMessage(); + message.ResponseClassifier = ResponseErrorClassifier200; + if (stream == true) + { + message.BufferResponse = false; + } + PipelineRequest request = message.Request; + request.Method = "POST"; + UriBuilder uriBuilder = new(_clientConnector.Endpoint.AbsoluteUri); + uriBuilder.Path += path; + request.Uri = uriBuilder.Uri; + request.Headers.Set("Content-Type", "application/json"); + request.Headers.Set("Accept", stream == true ? "text/event-stream" : "application/json"); + request.Content = content; + message.Apply(options ?? new()); + return message; + } } diff --git a/.dotnet/src/Custom/Assistants/AssistantClient.cs b/.dotnet/src/Custom/Assistants/AssistantClient.cs index 7543c0330..96082108f 100644 --- a/.dotnet/src/Custom/Assistants/AssistantClient.cs +++ b/.dotnet/src/Custom/Assistants/AssistantClient.cs @@ -1,8 +1,9 @@ +using OpenAI.Chat; +using OpenAI.Internal; using System; using System.ClientModel; using System.ClientModel.Primitives; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Threading.Tasks; namespace OpenAI.Assistants; @@ -10,7 +11,6 @@ namespace OpenAI.Assistants; /// /// The service client for OpenAI assistants. /// -[Experimental("OPENAI001")] public partial class AssistantClient { private OpenAIClientConnector _clientConnector; @@ -233,7 +233,7 @@ public virtual async Task> CreateThreadAsync( return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse()); } - public virtual ClientResult GetThread(string threadId) + public virtual ClientResult GetThread(string threadId) { ClientResult internalResult = ThreadShim.GetThread(threadId); return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse()); @@ -268,13 +268,13 @@ public virtual async Task> ModifyThreadAsync( return ClientResult.FromValue(new AssistantThread(internalResult.Value), internalResult.GetRawResponse()); } - public virtual ClientResult DeleteThread(string threadId) + public virtual ClientResult DeleteThread(string threadId) { ClientResult internalResult = ThreadShim.DeleteThread(threadId); return ClientResult.FromValue(internalResult.Value.Deleted, internalResult.GetRawResponse()); } - public virtual async Task> DeleteThreadAsync(string threadId) + public virtual async Task> DeleteThreadAsync(string threadId) { ClientResult internalResult = await ThreadShim.DeleteThreadAsync(threadId).ConfigureAwait(false); return ClientResult.FromValue(internalResult.Value.Deleted, internalResult.GetRawResponse()); @@ -459,6 +459,26 @@ public virtual async Task> CreateRunAsync( return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } + public virtual StreamingClientResult CreateRunStreaming( + string threadId, + string assistantId, + RunCreationOptions options = null) + { + using PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true); + RunShim.Pipeline.Send(message); + return CreateStreamingRunResult(message); + } + + public virtual async Task> CreateRunStreamingAsync( + string threadId, + string assistantId, + RunCreationOptions options = null) + { + using PipelineMessage message = CreateCreateRunRequest(threadId, assistantId, options, stream: true); + await RunShim.Pipeline.SendAsync(message); + return CreateStreamingRunResult(message); + } + public virtual ClientResult CreateThreadAndRun( string assistantId, ThreadCreationOptions threadOptions = null, @@ -481,6 +501,26 @@ Internal.Models.CreateThreadAndRunRequest request return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } + public virtual StreamingClientResult CreateThreadAndRunStreaming( + string assistantId, + ThreadCreationOptions threadOptions = null, + RunCreationOptions runOptions = null) + { + using PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); + Shim.Pipeline.Send(message); + return CreateStreamingRunResult(message); + } + + public virtual async Task> CreateThreadAndRunStreamingAsync( + string assistantId, + ThreadCreationOptions threadOptions = null, + RunCreationOptions runOptions = null) + { + using PipelineMessage message = CreateCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); + await Shim.Pipeline.SendAsync(message); + return CreateStreamingRunResult(message); + } + public virtual ClientResult GetRun(string threadId, string runId) { ClientResult internalResult = RunShim.GetRun(threadId, runId); @@ -560,7 +600,7 @@ public virtual ClientResult SubmitToolOutputs(string threadId, string requestToolOutputs.Add(new(toolOutput.Id, toolOutput.Output, null)); } - Internal.Models.SubmitToolOutputsRunRequest request = new(requestToolOutputs, null); + Internal.Models.SubmitToolOutputsRunRequest request = new(requestToolOutputs, null, serializedAdditionalRawData: null); ClientResult internalResult = RunShim.SubmitToolOuputsToRun(threadId, runId, request); return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } @@ -574,11 +614,77 @@ public virtual async Task> SubmitToolOutputsAsync(string requestToolOutputs.Add(new(toolOutput.Id, toolOutput.Output, null)); } - Internal.Models.SubmitToolOutputsRunRequest request = new(requestToolOutputs, null); + Internal.Models.SubmitToolOutputsRunRequest request = new(requestToolOutputs, null, serializedAdditionalRawData: null); ClientResult internalResult = await RunShim.SubmitToolOuputsToRunAsync(threadId, runId, request).ConfigureAwait(false); return ClientResult.FromValue(new ThreadRun(internalResult.Value), internalResult.GetRawResponse()); } + public virtual StreamingClientResult SubmitToolOutputsStreaming(string threadId, string runId, IEnumerable toolOutputs) + { + using PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); + Shim.Pipeline.SendAsync(message); + return CreateStreamingRunResult(message); + } + + public virtual async Task> SubmitToolOutputsStreamingAsync(string threadId, string runId, IEnumerable toolOutputs) + { + using PipelineMessage message = CreateSubmitToolOutputsRequest(threadId, runId, toolOutputs, stream: true); + await Shim.Pipeline.SendAsync(message); + return CreateStreamingRunResult(message); + } + + internal PipelineMessage CreateCreateRunRequest(string threadId, string assistantId, RunCreationOptions runOptions, bool? stream = null) + { + Internal.Models.CreateRunRequest internalCreateRunRequest = CreateInternalCreateRunRequest(assistantId, runOptions, stream); + BinaryContent requestBody = BinaryContent.Create(internalCreateRunRequest); + return CreateCreateRunRequest(threadId, requestBody, stream: true); + } + + internal PipelineMessage CreateCreateThreadAndRunRequest( + string assistantId, + ThreadCreationOptions threadOptions, + RunCreationOptions runOptions, + bool? stream = null) + { + Internal.Models.CreateThreadAndRunRequest internalRequest + = CreateInternalCreateThreadAndRunRequest(assistantId, threadOptions, runOptions, stream: true); + BinaryContent content = BinaryContent.Create(internalRequest); + return CreateCreateThreadAndRunRequest(content, stream: true); + } + + internal PipelineMessage CreateSubmitToolOutputsRequest(string threadId, string runId, IEnumerable toolOutputs, bool? stream) + { + List requestToolOutputs = []; + foreach (ToolOutput toolOutput in toolOutputs) + { + requestToolOutputs.Add(new(toolOutput.Id, toolOutput.Output, null)); + } + Internal.Models.SubmitToolOutputsRunRequest internalRequest = new(requestToolOutputs, stream, serializedAdditionalRawData: null); + BinaryContent content = BinaryContent.Create(internalRequest); + return CreateSubmitToolOutputsRequest(threadId, runId, content, stream: true); + } + + internal static StreamingClientResult CreateStreamingRunResult(PipelineMessage message) + { + if (message.Response.IsError) + { + throw new ClientResultException(message.Response); + } + + // TODO: why do we need to wrap this in a try-catch? + // Would putting the message in a `using` block suffice? + PipelineResponse response = message.Response; + try + { + // TODO: dust this part up... + return new StreamingAssistantResult(response); + } + finally + { + response?.Dispose(); + } + } + internal static Internal.Models.CreateAssistantRequest CreateInternalCreateAssistantRequest( string modelName, AssistantCreationOptions options) @@ -621,7 +727,8 @@ internal static Internal.Models.CreateThreadRequest CreateInternalCreateThreadRe internal static Internal.Models.CreateRunRequest CreateInternalCreateRunRequest( string assistantId, - RunCreationOptions options = null) + RunCreationOptions options = null, + bool? stream = null) { options ??= new(); return new( @@ -631,13 +738,15 @@ internal static Internal.Models.CreateRunRequest CreateInternalCreateRunRequest( options.AdditionalInstructions, ToInternalBinaryDataList(options.OverrideTools), options.Metadata, + stream, serializedAdditionalRawData: null); } internal static Internal.Models.CreateThreadAndRunRequest CreateInternalCreateThreadAndRunRequest( string assistantId, ThreadCreationOptions threadOptions, - RunCreationOptions runOptions) + RunCreationOptions runOptions, + bool? stream = null) { threadOptions ??= new(); runOptions ??= new(); @@ -649,6 +758,7 @@ internal static Internal.Models.CreateThreadAndRunRequest CreateInternalCreateTh runOptions.OverrideInstructions, ToInternalBinaryDataList(runOptions?.OverrideTools), runOptions?.Metadata, + stream, serializedAdditionalRawData: null); } @@ -717,4 +827,7 @@ internal virtual async Task>> GetListQueryPageAsyn ListQueryPage convertedValue = ListQueryPage.Create(internalResult.Value) as ListQueryPage; return ClientResult.FromValue(convertedValue, internalResult.GetRawResponse()); } + + private static PipelineMessageClassifier _responseErrorClassifier200; + private static PipelineMessageClassifier ResponseErrorClassifier200 => _responseErrorClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); } diff --git a/.dotnet/src/Custom/Assistants/FunctionToolDefinition.cs b/.dotnet/src/Custom/Assistants/FunctionToolDefinition.cs index d6687feec..87c63366b 100644 --- a/.dotnet/src/Custom/Assistants/FunctionToolDefinition.cs +++ b/.dotnet/src/Custom/Assistants/FunctionToolDefinition.cs @@ -82,7 +82,8 @@ internal override void WriteDerived(Utf8JsonWriter writer, ModelReaderWriterOpti if (Optional.IsDefined(Parameters)) { writer.WritePropertyName("parameters"u8); - writer.WriteRawValue(Parameters.ToString()); + using JsonDocument parametersJson = JsonDocument.Parse(Parameters); + parametersJson.WriteTo(writer); } writer.WriteEndObject(); } diff --git a/.dotnet/src/Custom/Assistants/FunctionToolInfo.cs b/.dotnet/src/Custom/Assistants/FunctionToolInfo.cs index b6edc1fc5..f4213e089 100644 --- a/.dotnet/src/Custom/Assistants/FunctionToolInfo.cs +++ b/.dotnet/src/Custom/Assistants/FunctionToolInfo.cs @@ -75,7 +75,9 @@ internal override void WriteDerived(Utf8JsonWriter writer, ModelReaderWriterOpti } if (Optional.IsDefined(Parameters)) { - writer.WriteRawValue(Parameters.ToString()); + writer.WritePropertyName("parameters"u8); + using JsonDocument parametersJson = JsonDocument.Parse(Parameters); + parametersJson.WriteTo(writer); } writer.WriteEndObject(); } diff --git a/.dotnet/src/Custom/Assistants/MessageContent.cs b/.dotnet/src/Custom/Assistants/MessageContent.cs index 928137621..9c41b4e49 100644 --- a/.dotnet/src/Custom/Assistants/MessageContent.cs +++ b/.dotnet/src/Custom/Assistants/MessageContent.cs @@ -3,4 +3,5 @@ namespace OpenAI.Assistants; public abstract partial class MessageContent { + public virtual string GetText() => (this as MessageTextContent).Text; } diff --git a/.dotnet/src/Custom/Assistants/MessageRole.Serialization.cs b/.dotnet/src/Custom/Assistants/MessageRole.Serialization.cs new file mode 100644 index 000000000..eb1a01972 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/MessageRole.Serialization.cs @@ -0,0 +1,29 @@ +using System; +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace OpenAI.Assistants; + +internal class MessageRoleSerialization +{ + public static MessageRole DeserializeMessageRole(JsonElement jsonElement, ModelReaderWriterOptions options = default) + { + if (jsonElement.ValueKind != JsonValueKind.String) + { + throw new ArgumentException(nameof(jsonElement)); + } + string roleName = jsonElement.GetString(); + if (roleName == "assistant") + { + return MessageRole.Assistant; + } + else if (roleName == "user") + { + return MessageRole.User; + } + else + { + throw new NotImplementedException(roleName); + } + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs b/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs new file mode 100644 index 000000000..060942470 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingAssistantResult.cs @@ -0,0 +1,97 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI.Assistants; + +#nullable enable + +internal class StreamingAssistantResult : StreamingClientResult +{ + public StreamingAssistantResult(PipelineResponse response) : base(response) + { + } + + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + // Note: this implementation disposes the stream after the caller has + // enumerated the elements obtained from the stream. That is to say, + // the `await foreach` loop can only happen once -- if it is tried a + // second time, the caller will get an ObjectDisposedException trying + // to access a disposed Stream. + using PipelineResponse response = GetRawResponse(); + + // Extract the content stream from the response to obtain dispose + // ownership of it. This means the content stream will not be disposed + // when the response is disposed. + Stream contentStream = response.ContentStream ?? throw new InvalidOperationException("Cannot enumerate null response ContentStream."); + response.ContentStream = null; + + return new AssistantRunUpdateEnumerator(contentStream); + } + + private class AssistantRunUpdateEnumerator : IAsyncEnumerator + { + private readonly IAsyncEnumerator _sseEvents; + + private List? _currentUpdates; + private int _currentUpdateIndex; + + public AssistantRunUpdateEnumerator(Stream stream) + { + AsyncSseReader reader = new AsyncSseReader(stream); + + // TODO: Pass CancellationToken. + _sseEvents = reader.GetAsyncEnumerator(); + } + + public StreamingUpdate Current => throw new NotImplementedException(); + + public async ValueTask MoveNextAsync() + { + // TODO: Can we wrap the boilerplate parts of this up into a special SSE base type for this? + // would that be public/internal/instantiated based on configuration given to the + // generator from the TSP? + + // Still have leftovers from the last event we pulled from the reader. + if (_currentUpdates is not null && _currentUpdateIndex < _currentUpdates.Count) + { + _currentUpdateIndex++; + return true; + } + + // We either don't have any stored updates, or we've exceeded the + // count of the ones we have. Get the next set. + + // TODO: Call different configure await variant in this context, or no? + if (!await _sseEvents.MoveNextAsync().ConfigureAwait(false)) + { + // Done with events from the stream. + return false; + } + + ServerSentEvent ssEvent = _sseEvents.Current; + + // TODO: optimize + BinaryData data = BinaryData.FromString(new string(ssEvent.Data.ToArray())); + + // TODO: don't instantiate every time. + StreamingUpdateCollection justToCreate = new StreamingUpdateCollection(); + + _currentUpdates = justToCreate.Create(data, ModelReaderWriterOptions.Json); + _currentUpdateIndex = 0; + + return true; + } + + public ValueTask DisposeAsync() + { + // TODO: revisit per platforms where async dispose is available. + _sseEvents?.DisposeAsync(); + return new ValueTask(); + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs new file mode 100644 index 000000000..e5869c3d9 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.Serialization.cs @@ -0,0 +1,13 @@ +namespace OpenAI.Assistants; + +using System.ClientModel.Primitives; +using System.Text.Json; + +public partial class StreamingMessageCompletion : StreamingUpdate +{ + internal static StreamingMessageCompletion DeserializeStreamingMessageCompletion(JsonElement element, ModelReaderWriterOptions options = default) + { + return new StreamingMessageCompletion( + new ThreadMessage(Internal.Models.MessageObject.DeserializeMessageObject(element, options))); + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs new file mode 100644 index 000000000..212e97910 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCompletion.cs @@ -0,0 +1,11 @@ +namespace OpenAI.Assistants; + +public partial class StreamingMessageCompletion : StreamingUpdate +{ + public ThreadMessage Message { get; } + + internal StreamingMessageCompletion(ThreadMessage message) + { + Message = message; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs new file mode 100644 index 000000000..170df7cd1 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.Serialization.cs @@ -0,0 +1,17 @@ +namespace OpenAI.Assistants; + +using System.ClientModel.Primitives; +using System.Text.Json; + +public partial class StreamingMessageCreation +{ + internal static StreamingUpdate DeserializeSseMessageCreation( + JsonElement sseDataJson, + ModelReaderWriterOptions options = null) + { + Internal.Models.MessageObject internalMessage + = Internal.Models.MessageObject.DeserializeMessageObject(sseDataJson, options); + ThreadMessage message = new(internalMessage); + return new StreamingMessageCreation(message); + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs new file mode 100644 index 000000000..c6641903d --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageCreation.cs @@ -0,0 +1,11 @@ +namespace OpenAI.Assistants; + +public partial class StreamingMessageCreation : StreamingUpdate +{ + public ThreadMessage Message { get; } + + internal StreamingMessageCreation(ThreadMessage message) + { + Message = message; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs new file mode 100644 index 000000000..d31176fdb --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.Serialization.cs @@ -0,0 +1,57 @@ +namespace OpenAI.Assistants; + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingMessageUpdate +{ + internal static IEnumerable DeserializeSseMessageUpdates( + JsonElement sseDataJson, + ModelReaderWriterOptions options = default) + { + List results = []; + if (sseDataJson.ValueKind == JsonValueKind.Null) + { + return results; + } + string id = null; + List<(int Index, MessageContent Content)> indexedContentItems = []; + foreach (JsonProperty property in sseDataJson.EnumerateObject()) + { + if (property.NameEquals("id"u8)) + { + id = property.Value.GetString(); + continue; + } + if (property.NameEquals("delta"u8)) + { + foreach (JsonProperty messageDeltaProperty in property.Value.EnumerateObject()) + { + if (messageDeltaProperty.NameEquals("content"u8)) + { + foreach (JsonElement contentItemElement in messageDeltaProperty.Value.EnumerateArray()) + { + MessageContent contentItem = MessageContent.DeserializeMessageContent(contentItemElement); + foreach (JsonProperty contentItemProperty in contentItemElement.EnumerateObject()) + { + if (contentItemProperty.NameEquals("index"u8)) + { + indexedContentItems.Add((contentItemProperty.Value.GetInt32(), contentItem)); + continue; + } + } + } + continue; + } + } + continue; + } + } + foreach((int index, MessageContent contentItem) in indexedContentItems) + { + results.Add(new StreamingMessageUpdate(contentItem, index)); + } + return results; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs new file mode 100644 index 000000000..30b5dacbd --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingMessageUpdate.cs @@ -0,0 +1,15 @@ +namespace OpenAI.Assistants; + +public partial class StreamingMessageUpdate : StreamingUpdate +{ + public MessageContent ContentUpdate { get; } + public int? ContentUpdateIndex { get; } + + internal StreamingMessageUpdate( + MessageContent contentUpdate, + int? contentIndex) + { + ContentUpdate = contentUpdate; + ContentUpdateIndex = contentIndex; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRequiredAction.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.Serialization.cs new file mode 100644 index 000000000..e89c6e0c1 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.Serialization.cs @@ -0,0 +1,19 @@ +namespace OpenAI.Assistants; + +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingRequiredAction +{ + internal static IEnumerable DeserializeStreamingRequiredActions(JsonElement sseDataJson, ModelReaderWriterOptions options = null) + { + ThreadRun run = new(Internal.Models.RunObject.DeserializeRunObject(sseDataJson, options)); + List results = []; + foreach (RunRequiredAction deserializedAction in run.RequiredActions) + { + results.Add(new(deserializedAction)); + } + return results; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs new file mode 100644 index 000000000..ccd037154 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRequiredAction.cs @@ -0,0 +1,11 @@ +namespace OpenAI.Assistants; + +public partial class StreamingRequiredAction : StreamingUpdate +{ + public RunRequiredAction RequiredAction { get; } + + internal StreamingRequiredAction(RunRequiredAction requiredAction) + { + RequiredAction = requiredAction; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs new file mode 100644 index 000000000..14bb26f94 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRunCreation.Serialization.cs @@ -0,0 +1,14 @@ +namespace OpenAI.Assistants; + +using System.ClientModel.Primitives; +using System.Text.Json; + +public partial class StreamingRunCreation +{ + internal static StreamingUpdate DeserializeStreamingRunCreation(JsonElement element, ModelReaderWriterOptions options = null) + { + Internal.Models.RunObject internalRun = Internal.Models.RunObject.DeserializeRunObject(element, options); + ThreadRun run = new(internalRun); + return new StreamingRunCreation(run); + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs b/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs new file mode 100644 index 000000000..89e628e43 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingRunCreation.cs @@ -0,0 +1,11 @@ +namespace OpenAI.Assistants; + +public partial class StreamingRunCreation : StreamingUpdate +{ + public ThreadRun Run { get; } + + internal StreamingRunCreation(ThreadRun run) + { + Run = run; + } +} diff --git a/.dotnet/src/Custom/Assistants/StreamingUpdate.Serialization.cs b/.dotnet/src/Custom/Assistants/StreamingUpdate.Serialization.cs new file mode 100644 index 000000000..7ef50fa57 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingUpdate.Serialization.cs @@ -0,0 +1,53 @@ +namespace OpenAI.Assistants; + +using System; +using System.Collections.Generic; +using System.Text.Json; + +public partial class StreamingUpdate +{ + internal static IEnumerable DeserializeSseRunUpdates( + ReadOnlyMemory sseEventName, + JsonElement sseDataJson) + { + IEnumerable results = []; + if (sseEventName.Span.SequenceEqual(s_threadMessageCreationEventName.Span)) + { + results = [StreamingMessageCreation.DeserializeSseMessageCreation(sseDataJson)]; + } + else if (sseEventName.Span.SequenceEqual(s_threadMessageDeltaEventName.Span)) + { + results = StreamingMessageUpdate.DeserializeSseMessageUpdates(sseDataJson); + } + else if (sseEventName.Span.SequenceEqual(s_threadMessageCompletionEventName.Span)) + { + results = [StreamingMessageCompletion.DeserializeStreamingMessageCompletion(sseDataJson)]; + } + else if (sseEventName.Span.SequenceEqual(s_runCreatedEventName.Span)) + { + results = [StreamingRunCreation.DeserializeStreamingRunCreation(sseDataJson)]; + } + + else if (sseEventName.Span.SequenceEqual(s_runRequiredActionEventName.Span)) + { + results = StreamingRequiredAction.DeserializeStreamingRequiredActions(sseDataJson); + } + else + { + results = [new StreamingUpdate()]; + } + + JsonElement rawElementClone = sseDataJson.Clone(); + foreach (StreamingUpdate baseUpdate in results) + { + baseUpdate._originalJson = rawElementClone; + } + return results; + } + + private static readonly ReadOnlyMemory s_threadMessageCreationEventName = "thread.message.created".AsMemory(); + private static readonly ReadOnlyMemory s_threadMessageDeltaEventName = "thread.message.delta".AsMemory(); + private static readonly ReadOnlyMemory s_threadMessageCompletionEventName = "thread.message.completed".AsMemory(); + private static readonly ReadOnlyMemory s_runCreatedEventName = "thread.run.created".AsMemory(); + private static readonly ReadOnlyMemory s_runRequiredActionEventName = "thread.run.requires_action".AsMemory(); +} diff --git a/.dotnet/src/Custom/Assistants/StreamingUpdate.cs b/.dotnet/src/Custom/Assistants/StreamingUpdate.cs new file mode 100644 index 000000000..a25f7eba4 --- /dev/null +++ b/.dotnet/src/Custom/Assistants/StreamingUpdate.cs @@ -0,0 +1,15 @@ +namespace OpenAI.Assistants; +using System.Text.Json; + +/// +/// Represents an incremental item of new data in a streaming response when running a thread with an assistant. +/// +public partial class StreamingUpdate +{ + private JsonElement _originalJson; + public JsonElement GetRawSseEvent() => _originalJson; + + protected StreamingUpdate() + { + } +} diff --git a/.dotnet/src/Custom/Chat/ChatClient.cs b/.dotnet/src/Custom/Chat/ChatClient.cs index 16a875c5a..a4b7b1352 100644 --- a/.dotnet/src/Custom/Chat/ChatClient.cs +++ b/.dotnet/src/Custom/Chat/ChatClient.cs @@ -40,8 +40,8 @@ public ChatClient(string model, ApiKeyCredential credential = default, OpenAICli /// The user message to provide as a prompt for chat completion. /// Additional options for the chat completion request. /// A result for a single chat completion. - public virtual ClientResult CompleteChat(string message, ChatCompletionOptions options = null) - => CompleteChat(new List() { new ChatRequestUserMessage(message) }, options); + public virtual ClientResult CompleteChat(string message, ChatCompletionOptions options = null) + => CompleteChat(new List() { new ChatRequestUserMessage(message) }, options); /// /// Generates a single chat completion result for a single, simple user message. @@ -49,9 +49,9 @@ public virtual ClientResult CompleteChat(string message, ChatCom /// The user message to provide as a prompt for chat completion. /// Additional options for the chat completion request. /// A result for a single chat completion. - public virtual Task> CompleteChatAsync(string message, ChatCompletionOptions options = null) - => CompleteChatAsync( - new List() { new ChatRequestUserMessage(message) }, options); + public virtual Task> CompleteChatAsync(string message, ChatCompletionOptions options = null) + => CompleteChatAsync( + new List() { new ChatRequestUserMessage(message) }, options); /// /// Generates a single chat completion result for a provided set of input chat messages. @@ -63,7 +63,7 @@ public virtual ClientResult CompleteChat( IEnumerable messages, ChatCompletionOptions options = null) { - Internal.Models.CreateChatCompletionRequest request = CreateInternalRequest(messages, options); + Internal.Models.CreateChatCompletionRequest request = CreateInternalRequest(messages, options); ClientResult response = Shim.CreateChatCompletion(request); ChatCompletion chatCompletion = new(response.Value, internalChoiceIndex: 0); return ClientResult.FromValue(chatCompletion, response.GetRawResponse()); @@ -93,7 +93,6 @@ public virtual async Task> CompleteChatAsync( /// The number of independent, alternative response choices that should be generated. /// /// Additional options for the chat completion request. - /// The cancellation token for the operation. /// A result for a single chat completion. public virtual ClientResult CompleteChat( IEnumerable messages, @@ -147,14 +146,14 @@ public virtual async Task> CompleteChatAs /// /// Additional options for the chat completion request. /// A streaming result with incremental chat completion updates. - public virtual StreamingClientResult CompleteChatStreaming( - string message, - int? choiceCount = null, - ChatCompletionOptions options = null) - => CompleteChatStreaming( - new List { new ChatRequestUserMessage(message) }, - choiceCount, - options); + public virtual StreamingClientResult CompleteChatStreaming( + string message, + int? choiceCount = null, + ChatCompletionOptions options = null) + => CompleteChatStreaming( + new List { new ChatRequestUserMessage(message) }, + choiceCount, + options); /// /// Begins a streaming response for a chat completion request using a single, simple user message as input. @@ -191,29 +190,25 @@ public virtual Task> CompleteChatStre /// The number of independent, alternative choices that the chat completion request should generate. /// /// Additional options for the chat completion request. - /// The cancellation token for the operation. /// A streaming result with incremental chat completion updates. public virtual StreamingClientResult CompleteChatStreaming( IEnumerable messages, int? choiceCount = null, ChatCompletionOptions options = null) { - PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options); - requestMessage.BufferResponse = false; - Shim.Pipeline.Send(requestMessage); - PipelineResponse response = requestMessage.ExtractResponse(); + PipelineMessage message = CreateCustomRequestMessage(messages, choiceCount, options); + message.BufferResponse = false; + + Shim.Pipeline.Send(message); + + PipelineResponse response = message.Response; if (response.IsError) { throw new ClientResultException(response); } - ClientResult genericResult = ClientResult.FromResponse(response); - return StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( - responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e))); + return new StreamingChatResult(response); } /// @@ -229,29 +224,25 @@ public virtual StreamingClientResult CompleteChatStreaming( /// The number of independent, alternative choices that the chat completion request should generate. /// /// Additional options for the chat completion request. - /// The cancellation token for the operation. /// A streaming result with incremental chat completion updates. public virtual async Task> CompleteChatStreamingAsync( IEnumerable messages, int? choiceCount = null, ChatCompletionOptions options = null) { - PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options); - requestMessage.BufferResponse = false; - await Shim.Pipeline.SendAsync(requestMessage).ConfigureAwait(false); - PipelineResponse response = requestMessage.ExtractResponse(); + PipelineMessage message = CreateCustomRequestMessage(messages, choiceCount, options); + message.BufferResponse = false; + + await Shim.Pipeline.SendAsync(message).ConfigureAwait(false); + + PipelineResponse response = message.Response; if (response.IsError) { throw new ClientResultException(response); } - ClientResult genericResult = ClientResult.FromResponse(response); - return StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( - responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e))); + return new StreamingChatResult(response); } private Internal.Models.CreateChatCompletionRequest CreateInternalRequest( @@ -326,4 +317,4 @@ private PipelineMessage CreateCustomRequestMessage(IEnumerable _responseErrorClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); -} +} \ No newline at end of file diff --git a/.dotnet/src/Custom/Chat/StreamingChatResult.cs b/.dotnet/src/Custom/Chat/StreamingChatResult.cs new file mode 100644 index 000000000..542470b00 --- /dev/null +++ b/.dotnet/src/Custom/Chat/StreamingChatResult.cs @@ -0,0 +1,93 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI.Chat; + +#nullable enable + +internal class StreamingChatResult : StreamingClientResult +{ + public StreamingChatResult(PipelineResponse response) : base(response) + { + } + + public override IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + // Note: this implementation disposes the stream after the caller has + // enumerated the elements obtained from the stream. That is to say, + // the `await foreach` loop can only happen once -- if it is tried a + // second time, the caller will get an ObjectDisposedException trying + // to access a disposed Stream. + using PipelineResponse response = GetRawResponse(); + + // Extract the content stream from the response to obtain dispose + // ownership of it. This means the content stream will not be disposed + // when the response is disposed. + Stream contentStream = response.ContentStream ?? throw new InvalidOperationException("Cannot enumerate null response ContentStream."); + response.ContentStream = null; + + return new ChatUpdateEnumerator(contentStream); + } + + private class ChatUpdateEnumerator : IAsyncEnumerator + { + private readonly IAsyncEnumerator _sseEvents; + + private List? _currentUpdates; + private int _currentUpdateIndex; + + public ChatUpdateEnumerator(Stream stream) + { + AsyncSseReader reader = new AsyncSseReader(stream); + + // TODO: Pass CancellationToken. + _sseEvents = reader.GetAsyncEnumerator(); + } + + public StreamingChatUpdate Current => throw new NotImplementedException(); + + public async ValueTask MoveNextAsync() + { + // Still have leftovers from the last event we pulled from the reader. + if (_currentUpdates is not null && _currentUpdateIndex < _currentUpdates.Count) + { + _currentUpdateIndex++; + return true; + } + + // We either don't have any stored updates, or we've exceeded the + // count of the ones we have. Get the next set. + + // TODO: Call different configure await variant in this context, or no? + if (!await _sseEvents.MoveNextAsync().ConfigureAwait(false)) + { + // Done with events from the stream. + return false; + } + + ServerSentEvent ssEvent = _sseEvents.Current; + + // TODO: optimize + BinaryData data = BinaryData.FromString(new string(ssEvent.Data.ToArray())); + + // TODO: don't instantiate every time. + StreamingChatUpdateCollection justToCreate = new StreamingChatUpdateCollection(); + + _currentUpdates = justToCreate.Create(data, ModelReaderWriterOptions.Json); + _currentUpdateIndex = 0; + + return true; + } + + public ValueTask DisposeAsync() + { + // TODO: revisit per platforms where async dispose is available. + _sseEvents?.DisposeAsync(); + return new ValueTask(); + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs index c1540897b..604ff5110 100644 --- a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs @@ -1,13 +1,14 @@ -namespace OpenAI.Chat; - using System; +using System.ClientModel.Primitives; using System.Collections.Generic; using System.Text.Json; +namespace OpenAI.Chat; + /// /// Represents an incremental item of new data in a streaming response to a chat completion request. /// -public partial class StreamingChatUpdate +public partial class StreamingChatUpdate : IJsonModel { /// /// Gets a unique identifier associated with this streamed Chat Completions response. @@ -182,155 +183,28 @@ internal StreamingChatUpdate( LogProbabilities = logProbabilities; } - internal static List DeserializeStreamingChatUpdates(JsonElement element) + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) { - List results = []; - if (element.ValueKind == JsonValueKind.Null) - { - return results; - } - string id = default; - DateTimeOffset created = default; - string systemFingerprint = null; - foreach (JsonProperty property in element.EnumerateObject()) - { - if (property.NameEquals("id"u8)) - { - id = property.Value.GetString(); - continue; - } - if (property.NameEquals("created"u8)) - { - created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - continue; - } - if (property.NameEquals("system_fingerprint")) - { - systemFingerprint = property.Value.GetString(); - continue; - } - if (property.NameEquals("choices"u8)) - { - foreach (JsonElement choiceElement in property.Value.EnumerateArray()) - { - ChatRole? role = null; - string contentUpdate = null; - string functionName = null; - string functionArgumentsUpdate = null; - int choiceIndex = 0; - ChatFinishReason? finishReason = null; - List toolCallUpdates = []; - ChatLogProbabilityCollection logProbabilities = null; + throw new NotImplementedException(); + } + + public StreamingChatUpdate Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } - foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) - { - if (choiceProperty.NameEquals("index"u8)) - { - choiceIndex = choiceProperty.Value.GetInt32(); - continue; - } - if (choiceProperty.NameEquals("finish_reason"u8)) - { - if (choiceProperty.Value.ValueKind == JsonValueKind.Null) - { - finishReason = null; - continue; - } - finishReason = choiceProperty.Value.GetString() switch - { - "stop" => ChatFinishReason.Stopped, - "length" => ChatFinishReason.Length, - "tool_calls" => ChatFinishReason.ToolCalls, - "function_call" => ChatFinishReason.FunctionCall, - "content_filter" => ChatFinishReason.ContentFilter, - _ => throw new ArgumentException(nameof(finishReason)), - }; - continue; - } - if (choiceProperty.NameEquals("delta"u8)) - { - foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) - { - if (deltaProperty.NameEquals("role"u8)) - { - role = deltaProperty.Value.GetString() switch - { - "system" => ChatRole.System, - "user" => ChatRole.User, - "assistant" => ChatRole.Assistant, - "tool" => ChatRole.Tool, - "function" => ChatRole.Function, - _ => throw new ArgumentException(nameof(role)), - }; - continue; - } - if (deltaProperty.NameEquals("content"u8)) - { - contentUpdate = deltaProperty.Value.GetString(); - continue; - } - if (deltaProperty.NameEquals("function_call"u8)) - { - foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) - { - if (functionProperty.NameEquals("name"u8)) - { - functionName = functionProperty.Value.GetString(); - continue; - } - if (functionProperty.NameEquals("arguments"u8)) - { - functionArgumentsUpdate = functionProperty.Value.GetString(); - } - } - } - if (deltaProperty.NameEquals("tool_calls")) - { - foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray()) - { - toolCallUpdates.Add( - StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement)); - } - } - } - } - if (choiceProperty.NameEquals("logprobs"u8)) - { - Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs - = Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs( - choiceProperty.Value); - logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs); - } - } - // In the unlikely event that more than one tool call arrives on a single chunk, we'll generate - // separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop. - if (toolCallUpdates.Count == 0) - { - toolCallUpdates.Add(null); - } - foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) - { - results.Add(new StreamingChatUpdate( - id, - created, - systemFingerprint, - choiceIndex, - role, - contentUpdate, - finishReason, - functionName, - functionArgumentsUpdate, - toolCallUpdate, - logProbabilities)); - } - } - continue; - } - } - if (results.Count == 0) - { - results.Add(new StreamingChatUpdate(id, created, systemFingerprint)); - } - return results; + public BinaryData Write(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public StreamingChatUpdate Create(BinaryData data, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public string GetFormatFromOptions(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); } -} +} \ No newline at end of file diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs new file mode 100644 index 000000000..8b7466848 --- /dev/null +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs @@ -0,0 +1,177 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; +namespace OpenAI.Chat; + +internal class StreamingChatUpdateCollection : StreamedEventCollection +{ + internal static StreamingChatUpdateCollection DeserializeSseChatUpdates(ReadOnlyMemory _, JsonElement sseDataJson) + { + // TODO: would another enumerable implementation be more performant than list? + StreamingChatUpdateCollection results = []; + + // TODO: Do we need to validate that we didn't get null or empty? + // What's the contract for the JSON updates? + + if (sseDataJson.ValueKind == JsonValueKind.Null) + { + return results; + } + + string id = default; + DateTimeOffset created = default; + string systemFingerprint = null; + foreach (JsonProperty property in sseDataJson.EnumerateObject()) + { + if (property.NameEquals("id"u8)) + { + id = property.Value.GetString(); + continue; + } + if (property.NameEquals("created"u8)) + { + created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + continue; + } + if (property.NameEquals("system_fingerprint")) + { + systemFingerprint = property.Value.GetString(); + continue; + } + if (property.NameEquals("choices"u8)) + { + foreach (JsonElement choiceElement in property.Value.EnumerateArray()) + { + ChatRole? role = null; + string contentUpdate = null; + string functionName = null; + string functionArgumentsUpdate = null; + int choiceIndex = 0; + ChatFinishReason? finishReason = null; + List toolCallUpdates = []; + ChatLogProbabilityCollection logProbabilities = null; + + foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) + { + if (choiceProperty.NameEquals("index"u8)) + { + choiceIndex = choiceProperty.Value.GetInt32(); + continue; + } + if (choiceProperty.NameEquals("finish_reason"u8)) + { + if (choiceProperty.Value.ValueKind == JsonValueKind.Null) + { + finishReason = null; + continue; + } + finishReason = choiceProperty.Value.GetString() switch + { + "stop" => ChatFinishReason.Stopped, + "length" => ChatFinishReason.Length, + "tool_calls" => ChatFinishReason.ToolCalls, + "function_call" => ChatFinishReason.FunctionCall, + "content_filter" => ChatFinishReason.ContentFilter, + _ => throw new ArgumentException(nameof(finishReason)), + }; + continue; + } + if (choiceProperty.NameEquals("delta"u8)) + { + foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) + { + if (deltaProperty.NameEquals("role"u8)) + { + role = deltaProperty.Value.GetString() switch + { + "system" => ChatRole.System, + "user" => ChatRole.User, + "assistant" => ChatRole.Assistant, + "tool" => ChatRole.Tool, + "function" => ChatRole.Function, + _ => throw new ArgumentException(nameof(role)), + }; + continue; + } + if (deltaProperty.NameEquals("content"u8)) + { + contentUpdate = deltaProperty.Value.GetString(); + continue; + } + if (deltaProperty.NameEquals("function_call"u8)) + { + foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) + { + if (functionProperty.NameEquals("name"u8)) + { + functionName = functionProperty.Value.GetString(); + continue; + } + if (functionProperty.NameEquals("arguments"u8)) + { + functionArgumentsUpdate = functionProperty.Value.GetString(); + } + } + } + if (deltaProperty.NameEquals("tool_calls")) + { + foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray()) + { + toolCallUpdates.Add( + StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement)); + } + } + } + } + if (choiceProperty.NameEquals("logprobs"u8)) + { + Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs + = Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs( + choiceProperty.Value); + logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs); + } + } + // In the unlikely event that more than one tool call arrives on a single chunk, we'll generate + // separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop. + if (toolCallUpdates.Count == 0) + { + toolCallUpdates.Add(null); + } + foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) + { + results.Add(new StreamingChatUpdate( + id, + created, + systemFingerprint, + choiceIndex, + role, + contentUpdate, + finishReason, + functionName, + functionArgumentsUpdate, + toolCallUpdate, + logProbabilities)); + } + } + continue; + } + } + if (results.Count == 0) + { + results.Add(new StreamingChatUpdate(id, created, systemFingerprint)); + } + return results; + } + + public override StreamedEventCollection Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public override StreamedEventCollection Create(BinaryData data, ModelReaderWriterOptions options) + { + using JsonDocument doc = JsonDocument.Parse(data); + return DeserializeSseChatUpdates(null, doc.RootElement); + } +} diff --git a/.dotnet/src/Generated/Models/CreateRunRequest.Serialization.cs b/.dotnet/src/Generated/Models/CreateRunRequest.Serialization.cs index 15ade4fd0..c7cb96fc3 100644 --- a/.dotnet/src/Generated/Models/CreateRunRequest.Serialization.cs +++ b/.dotnet/src/Generated/Models/CreateRunRequest.Serialization.cs @@ -106,6 +106,11 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriter writer.WriteNull("metadata"); } } + if (Optional.IsDefined(Stream)) + { + writer.WritePropertyName("stream"u8); + writer.WriteBooleanValue(Stream.Value); + } if (options.Format != "W" && _serializedAdditionalRawData != null) { foreach (var item in _serializedAdditionalRawData) @@ -150,6 +155,7 @@ internal static CreateRunRequest DeserializeCreateRunRequest(JsonElement element string additionalInstructions = default; IList tools = default; IDictionary metadata = default; + bool? stream = default; IDictionary serializedAdditionalRawData = default; Dictionary additionalPropertiesDictionary = new Dictionary(); foreach (var property in element.EnumerateObject()) @@ -224,6 +230,15 @@ internal static CreateRunRequest DeserializeCreateRunRequest(JsonElement element metadata = dictionary; continue; } + if (property.NameEquals("stream"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + stream = property.Value.GetBoolean(); + continue; + } if (options.Format != "W") { additionalPropertiesDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); @@ -237,6 +252,7 @@ internal static CreateRunRequest DeserializeCreateRunRequest(JsonElement element additionalInstructions, tools ?? new ChangeTrackingList(), metadata ?? new ChangeTrackingDictionary(), + stream, serializedAdditionalRawData); } diff --git a/.dotnet/src/Generated/Models/CreateRunRequest.cs b/.dotnet/src/Generated/Models/CreateRunRequest.cs index db7f2f182..3e7e9c6d9 100644 --- a/.dotnet/src/Generated/Models/CreateRunRequest.cs +++ b/.dotnet/src/Generated/Models/CreateRunRequest.cs @@ -77,8 +77,12 @@ public CreateRunRequest(string assistantId) /// additional information about the object in a structured format. Keys can be a maximum of 64 /// characters long and values can be a maxium of 512 characters long. /// + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// /// Keeps track of any properties unknown to the library. - internal CreateRunRequest(string assistantId, string model, string instructions, string additionalInstructions, IList tools, IDictionary metadata, IDictionary serializedAdditionalRawData) + internal CreateRunRequest(string assistantId, string model, string instructions, string additionalInstructions, IList tools, IDictionary metadata, bool? stream, IDictionary serializedAdditionalRawData) { AssistantId = assistantId; Model = model; @@ -86,6 +90,7 @@ internal CreateRunRequest(string assistantId, string model, string instructions, AdditionalInstructions = additionalInstructions; Tools = tools; Metadata = metadata; + Stream = stream; _serializedAdditionalRawData = serializedAdditionalRawData; } @@ -150,5 +155,10 @@ internal CreateRunRequest() /// characters long and values can be a maxium of 512 characters long. /// public IDictionary Metadata { get; set; } + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// + public bool? Stream { get; set; } } } diff --git a/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.Serialization.cs b/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.Serialization.cs index 1613f3680..7fc949ee8 100644 --- a/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.Serialization.cs +++ b/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.Serialization.cs @@ -99,6 +99,11 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelRea writer.WriteNull("metadata"); } } + if (Optional.IsDefined(Stream)) + { + writer.WritePropertyName("stream"u8); + writer.WriteBooleanValue(Stream.Value); + } if (options.Format != "W" && _serializedAdditionalRawData != null) { foreach (var item in _serializedAdditionalRawData) @@ -143,6 +148,7 @@ internal static CreateThreadAndRunRequest DeserializeCreateThreadAndRunRequest(J string instructions = default; IList tools = default; IDictionary metadata = default; + bool? stream = default; IDictionary serializedAdditionalRawData = default; Dictionary additionalPropertiesDictionary = new Dictionary(); foreach (var property in element.EnumerateObject()) @@ -216,6 +222,15 @@ internal static CreateThreadAndRunRequest DeserializeCreateThreadAndRunRequest(J metadata = dictionary; continue; } + if (property.NameEquals("stream"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + stream = property.Value.GetBoolean(); + continue; + } if (options.Format != "W") { additionalPropertiesDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); @@ -229,6 +244,7 @@ internal static CreateThreadAndRunRequest DeserializeCreateThreadAndRunRequest(J instructions, tools ?? new ChangeTrackingList(), metadata ?? new ChangeTrackingDictionary(), + stream, serializedAdditionalRawData); } diff --git a/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.cs b/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.cs index eb2b0b0ba..3fc793df3 100644 --- a/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.cs +++ b/.dotnet/src/Generated/Models/CreateThreadAndRunRequest.cs @@ -74,8 +74,12 @@ public CreateThreadAndRunRequest(string assistantId) /// additional information about the object in a structured format. Keys can be a maximum of 64 /// characters long and values can be a maxium of 512 characters long. /// + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// /// Keeps track of any properties unknown to the library. - internal CreateThreadAndRunRequest(string assistantId, CreateThreadRequest thread, string model, string instructions, IList tools, IDictionary metadata, IDictionary serializedAdditionalRawData) + internal CreateThreadAndRunRequest(string assistantId, CreateThreadRequest thread, string model, string instructions, IList tools, IDictionary metadata, bool? stream, IDictionary serializedAdditionalRawData) { AssistantId = assistantId; Thread = thread; @@ -83,6 +87,7 @@ internal CreateThreadAndRunRequest(string assistantId, CreateThreadRequest threa Instructions = instructions; Tools = tools; Metadata = metadata; + Stream = stream; _serializedAdditionalRawData = serializedAdditionalRawData; } @@ -144,5 +149,10 @@ internal CreateThreadAndRunRequest() /// characters long and values can be a maxium of 512 characters long. /// public IDictionary Metadata { get; set; } + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// + public bool? Stream { get; set; } } } diff --git a/.dotnet/src/Generated/Models/MessageObject.Serialization.cs b/.dotnet/src/Generated/Models/MessageObject.Serialization.cs index 5c0cb8c53..0bbbc7db0 100644 --- a/.dotnet/src/Generated/Models/MessageObject.Serialization.cs +++ b/.dotnet/src/Generated/Models/MessageObject.Serialization.cs @@ -29,6 +29,35 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOpt writer.WriteNumberValue(CreatedAt, "U"); writer.WritePropertyName("thread_id"u8); writer.WriteStringValue(ThreadId); + writer.WritePropertyName("status"u8); + writer.WriteStringValue(Status.ToString()); + if (IncompleteDetails != null) + { + writer.WritePropertyName("incomplete_details"u8); + writer.WriteObjectValue(IncompleteDetails); + } + else + { + writer.WriteNull("incomplete_details"); + } + if (CompletedAt != null) + { + writer.WritePropertyName("completed_at"u8); + writer.WriteStringValue(CompletedAt.Value, "O"); + } + else + { + writer.WriteNull("completed_at"); + } + if (IncompleteAt != null) + { + writer.WritePropertyName("incomplete_at"u8); + writer.WriteStringValue(IncompleteAt.Value, "O"); + } + else + { + writer.WriteNull("incomplete_at"); + } writer.WritePropertyName("role"u8); writer.WriteStringValue(Role.ToString()); writer.WritePropertyName("content"u8); @@ -132,6 +161,10 @@ internal static MessageObject DeserializeMessageObject(JsonElement element, Mode MessageObjectObject @object = default; DateTimeOffset createdAt = default; string threadId = default; + MessageObjectStatus status = default; + MessageObjectIncompleteDetails incompleteDetails = default; + DateTimeOffset? completedAt = default; + DateTimeOffset? incompleteAt = default; MessageObjectRole role = default; IReadOnlyList content = default; string assistantId = default; @@ -162,6 +195,45 @@ internal static MessageObject DeserializeMessageObject(JsonElement element, Mode threadId = property.Value.GetString(); continue; } + if (property.NameEquals("status"u8)) + { + status = new MessageObjectStatus(property.Value.GetString()); + continue; + } + if (property.NameEquals("incomplete_details"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + incompleteDetails = null; + continue; + } + incompleteDetails = MessageObjectIncompleteDetails.DeserializeMessageObjectIncompleteDetails(property.Value, options); + continue; + } + if (property.NameEquals("completed_at"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + completedAt = null; + continue; + } + // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // completedAt = property.Value.GetDateTimeOffset("O"); + completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + continue; + } + if (property.NameEquals("incomplete_at"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + incompleteAt = null; + continue; + } + // BUG: https://github.com/Azure/autorest.csharp/issues/4296 + // completedAt = property.Value.GetDateTimeOffset("O"); + completedAt = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + continue; + } if (property.NameEquals("role"u8)) { role = new MessageObjectRole(property.Value.GetString()); @@ -240,6 +312,10 @@ internal static MessageObject DeserializeMessageObject(JsonElement element, Mode @object, createdAt, threadId, + status, + incompleteDetails, + completedAt, + incompleteAt, role, content, assistantId, diff --git a/.dotnet/src/Generated/Models/MessageObject.cs b/.dotnet/src/Generated/Models/MessageObject.cs index e314eaf06..ae218e713 100644 --- a/.dotnet/src/Generated/Models/MessageObject.cs +++ b/.dotnet/src/Generated/Models/MessageObject.cs @@ -46,6 +46,10 @@ internal partial class MessageObject /// The identifier, which can be referenced in API endpoints. /// The Unix timestamp (in seconds) for when the message was created. /// The [thread](/docs/api-reference/threads) ID that this message belongs to. + /// The status of the message, which can be either in_progress, incomplete, or completed. + /// On an incomplete message, details about why the message is incomplete. + /// The Unix timestamp at which the message was completed. + /// The Unix timestamp at which the message was marked as incomplete. /// The entity that produced the message. One of `user` or `assistant`. /// The content of the message in array of text and/or images. /// @@ -67,7 +71,7 @@ internal partial class MessageObject /// characters long and values can be a maxium of 512 characters long. /// /// , , or is null. - internal MessageObject(string id, DateTimeOffset createdAt, string threadId, MessageObjectRole role, IEnumerable content, string assistantId, string runId, IEnumerable fileIds, IReadOnlyDictionary metadata) + internal MessageObject(string id, DateTimeOffset createdAt, string threadId, MessageObjectStatus status, MessageObjectIncompleteDetails incompleteDetails, DateTimeOffset? completedAt, DateTimeOffset? incompleteAt, MessageObjectRole role, IEnumerable content, string assistantId, string runId, IEnumerable fileIds, IReadOnlyDictionary metadata) { Argument.AssertNotNull(id, nameof(id)); Argument.AssertNotNull(threadId, nameof(threadId)); @@ -77,6 +81,10 @@ internal MessageObject(string id, DateTimeOffset createdAt, string threadId, Mes Id = id; CreatedAt = createdAt; ThreadId = threadId; + Status = status; + IncompleteDetails = incompleteDetails; + CompletedAt = completedAt; + IncompleteAt = incompleteAt; Role = role; Content = content.ToList(); AssistantId = assistantId; @@ -90,6 +98,10 @@ internal MessageObject(string id, DateTimeOffset createdAt, string threadId, Mes /// The object type, which is always `thread.message`. /// The Unix timestamp (in seconds) for when the message was created. /// The [thread](/docs/api-reference/threads) ID that this message belongs to. + /// The status of the message, which can be either in_progress, incomplete, or completed. + /// On an incomplete message, details about why the message is incomplete. + /// The Unix timestamp at which the message was completed. + /// The Unix timestamp at which the message was marked as incomplete. /// The entity that produced the message. One of `user` or `assistant`. /// The content of the message in array of text and/or images. /// @@ -111,12 +123,16 @@ internal MessageObject(string id, DateTimeOffset createdAt, string threadId, Mes /// characters long and values can be a maxium of 512 characters long. /// /// Keeps track of any properties unknown to the library. - internal MessageObject(string id, MessageObjectObject @object, DateTimeOffset createdAt, string threadId, MessageObjectRole role, IReadOnlyList content, string assistantId, string runId, IReadOnlyList fileIds, IReadOnlyDictionary metadata, IDictionary serializedAdditionalRawData) + internal MessageObject(string id, MessageObjectObject @object, DateTimeOffset createdAt, string threadId, MessageObjectStatus status, MessageObjectIncompleteDetails incompleteDetails, DateTimeOffset? completedAt, DateTimeOffset? incompleteAt, MessageObjectRole role, IReadOnlyList content, string assistantId, string runId, IReadOnlyList fileIds, IReadOnlyDictionary metadata, IDictionary serializedAdditionalRawData) { Id = id; Object = @object; CreatedAt = createdAt; ThreadId = threadId; + Status = status; + IncompleteDetails = incompleteDetails; + CompletedAt = completedAt; + IncompleteAt = incompleteAt; Role = role; Content = content; AssistantId = assistantId; @@ -140,6 +156,14 @@ internal MessageObject() public DateTimeOffset CreatedAt { get; } /// The [thread](/docs/api-reference/threads) ID that this message belongs to. public string ThreadId { get; } + /// The status of the message, which can be either in_progress, incomplete, or completed. + public MessageObjectStatus Status { get; } + /// On an incomplete message, details about why the message is incomplete. + public MessageObjectIncompleteDetails IncompleteDetails { get; } + /// The Unix timestamp at which the message was completed. + public DateTimeOffset? CompletedAt { get; } + /// The Unix timestamp at which the message was marked as incomplete. + public DateTimeOffset? IncompleteAt { get; } /// The entity that produced the message. One of `user` or `assistant`. public MessageObjectRole Role { get; } /// diff --git a/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs new file mode 100644 index 000000000..d7683486e --- /dev/null +++ b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.Serialization.cs @@ -0,0 +1,130 @@ +// + +using System; +using OpenAI.ClientShared.Internal; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; +using OpenAI; + +namespace OpenAI.Internal.Models +{ + internal partial class MessageObjectIncompleteDetails : IJsonModel + { + void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(MessageObjectIncompleteDetails)} does not support '{format}' format."); + } + + writer.WriteStartObject(); + writer.WritePropertyName("reason"u8); + writer.WriteStringValue(Reason); + if (options.Format != "W" && _serializedAdditionalRawData != null) + { + foreach (var item in _serializedAdditionalRawData) + { + writer.WritePropertyName(item.Key); +#if NET6_0_OR_GREATER + writer.WriteRawValue(item.Value); +#else + using (JsonDocument document = JsonDocument.Parse(item.Value)) + { + JsonSerializer.Serialize(writer, document.RootElement); + } +#endif + } + } + writer.WriteEndObject(); + } + + MessageObjectIncompleteDetails IJsonModel.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(MessageObjectIncompleteDetails)} does not support '{format}' format."); + } + + using JsonDocument document = JsonDocument.ParseValue(ref reader); + return DeserializeMessageObjectIncompleteDetails(document.RootElement, options); + } + + internal static MessageObjectIncompleteDetails DeserializeMessageObjectIncompleteDetails(JsonElement element, ModelReaderWriterOptions options = null) + { + options ??= new ModelReaderWriterOptions("W"); + + if (element.ValueKind == JsonValueKind.Null) + { + return null; + } + string reason = default; + IDictionary serializedAdditionalRawData = default; + Dictionary additionalPropertiesDictionary = new Dictionary(); + foreach (var property in element.EnumerateObject()) + { + if (property.NameEquals("reason"u8)) + { + reason = property.Value.GetString(); + continue; + } + if (options.Format != "W") + { + additionalPropertiesDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); + } + } + serializedAdditionalRawData = additionalPropertiesDictionary; + return new MessageObjectIncompleteDetails(reason, serializedAdditionalRawData); + } + + BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + return ModelReaderWriter.Write(this, options); + default: + throw new FormatException($"The model {nameof(MessageObjectIncompleteDetails)} does not support '{options.Format}' format."); + } + } + + MessageObjectIncompleteDetails IPersistableModel.Create(BinaryData data, ModelReaderWriterOptions options) + { + var format = options.Format == "W" ? ((IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + { + using JsonDocument document = JsonDocument.Parse(data); + return DeserializeMessageObjectIncompleteDetails(document.RootElement, options); + } + default: + throw new FormatException($"The model {nameof(MessageObjectIncompleteDetails)} does not support '{options.Format}' format."); + } + } + + string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; + + /// Deserializes the model from a raw response. + /// The result to deserialize the model from. + internal static MessageObjectIncompleteDetails FromResponse(PipelineResponse response) + { + using var document = JsonDocument.Parse(response.Content); + return DeserializeMessageObjectIncompleteDetails(document.RootElement); + } + + /// Convert into a Utf8JsonRequestBody. + internal virtual BinaryContent ToRequestBody() + { + var content = new Utf8JsonRequestBody(); + content.JsonWriter.WriteObjectValue(this); + return content; + } + } +} diff --git a/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs new file mode 100644 index 000000000..4033843d0 --- /dev/null +++ b/.dotnet/src/Generated/Models/MessageObjectIncompleteDetails.cs @@ -0,0 +1,71 @@ +// + +using System; +using System.Collections.Generic; +using OpenAI; + +namespace OpenAI.Internal.Models +{ + /// The MessageObjectIncompleteDetails. + internal partial class MessageObjectIncompleteDetails + { + /// + /// Keeps track of any properties unknown to the library. + /// + /// To assign an object to the value of this property use . + /// + /// + /// To assign an already formatted json string to this property use . + /// + /// + /// Examples: + /// + /// + /// BinaryData.FromObjectAsJson("foo") + /// Creates a payload of "foo". + /// + /// + /// BinaryData.FromString("\"foo\"") + /// Creates a payload of "foo". + /// + /// + /// BinaryData.FromObjectAsJson(new { key = "value" }) + /// Creates a payload of { "key": "value" }. + /// + /// + /// BinaryData.FromString("{\"key\": \"value\"}") + /// Creates a payload of { "key": "value" }. + /// + /// + /// + /// + private IDictionary _serializedAdditionalRawData; + + /// Initializes a new instance of . + /// The reason the message is incomplete. + /// is null. + internal MessageObjectIncompleteDetails(string reason) + { + Argument.AssertNotNull(reason, nameof(reason)); + + Reason = reason; + } + + /// Initializes a new instance of . + /// The reason the message is incomplete. + /// Keeps track of any properties unknown to the library. + internal MessageObjectIncompleteDetails(string reason, IDictionary serializedAdditionalRawData) + { + Reason = reason; + _serializedAdditionalRawData = serializedAdditionalRawData; + } + + /// Initializes a new instance of for deserialization. + internal MessageObjectIncompleteDetails() + { + } + + /// The reason the message is incomplete. + public string Reason { get; } + } +} diff --git a/.dotnet/src/Generated/Models/MessageObjectStatus.cs b/.dotnet/src/Generated/Models/MessageObjectStatus.cs new file mode 100644 index 000000000..b1ca494f3 --- /dev/null +++ b/.dotnet/src/Generated/Models/MessageObjectStatus.cs @@ -0,0 +1,49 @@ +// + +using System; +using System.ComponentModel; + +namespace OpenAI.Internal.Models +{ + /// Enum for status in MessageObject. + internal readonly partial struct MessageObjectStatus : IEquatable + { + private readonly string _value; + + /// Initializes a new instance of . + /// is null. + public MessageObjectStatus(string value) + { + _value = value ?? throw new ArgumentNullException(nameof(value)); + } + + private const string InProgressValue = "in_progress"; + private const string IncompleteValue = "incomplete"; + private const string CompletedValue = "completed"; + + /// in_progress. + public static MessageObjectStatus InProgress { get; } = new MessageObjectStatus(InProgressValue); + /// incomplete. + public static MessageObjectStatus Incomplete { get; } = new MessageObjectStatus(IncompleteValue); + /// completed. + public static MessageObjectStatus Completed { get; } = new MessageObjectStatus(CompletedValue); + /// Determines if two values are the same. + public static bool operator ==(MessageObjectStatus left, MessageObjectStatus right) => left.Equals(right); + /// Determines if two values are not the same. + public static bool operator !=(MessageObjectStatus left, MessageObjectStatus right) => !left.Equals(right); + /// Converts a string to a . + public static implicit operator MessageObjectStatus(string value) => new MessageObjectStatus(value); + + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public override bool Equals(object obj) => obj is MessageObjectStatus other && Equals(other); + /// + public bool Equals(MessageObjectStatus other) => string.Equals(_value, other._value, StringComparison.InvariantCultureIgnoreCase); + + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public override int GetHashCode() => _value?.GetHashCode() ?? 0; + /// + public override string ToString() => _value; + } +} diff --git a/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.Serialization.cs b/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.Serialization.cs index d38502d9b..be9df651c 100644 --- a/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.Serialization.cs +++ b/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.Serialization.cs @@ -28,6 +28,11 @@ void IJsonModel.Write(Utf8JsonWriter writer, ModelR writer.WriteObjectValue(item); } writer.WriteEndArray(); + if (Optional.IsDefined(Stream)) + { + writer.WritePropertyName("stream"u8); + writer.WriteBooleanValue(Stream.Value); + } if (options.Format != "W" && _serializedAdditionalRawData != null) { foreach (var item in _serializedAdditionalRawData) @@ -67,6 +72,7 @@ internal static SubmitToolOutputsRunRequest DeserializeSubmitToolOutputsRunReque return null; } IList toolOutputs = default; + bool? stream = default; IDictionary serializedAdditionalRawData = default; Dictionary additionalPropertiesDictionary = new Dictionary(); foreach (var property in element.EnumerateObject()) @@ -81,13 +87,22 @@ internal static SubmitToolOutputsRunRequest DeserializeSubmitToolOutputsRunReque toolOutputs = array; continue; } + if (property.NameEquals("stream"u8)) + { + if (property.Value.ValueKind == JsonValueKind.Null) + { + continue; + } + stream = property.Value.GetBoolean(); + continue; + } if (options.Format != "W") { additionalPropertiesDictionary.Add(property.Name, BinaryData.FromString(property.Value.GetRawText())); } } serializedAdditionalRawData = additionalPropertiesDictionary; - return new SubmitToolOutputsRunRequest(toolOutputs, serializedAdditionalRawData); + return new SubmitToolOutputsRunRequest(toolOutputs, stream, serializedAdditionalRawData); } BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) diff --git a/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.cs b/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.cs index 2bed85b35..a4395b14c 100644 --- a/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.cs +++ b/.dotnet/src/Generated/Models/SubmitToolOutputsRunRequest.cs @@ -54,10 +54,15 @@ public SubmitToolOutputsRunRequest(IEnumerable Initializes a new instance of . /// A list of tools for which the outputs are being submitted. + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state. + /// /// Keeps track of any properties unknown to the library. - internal SubmitToolOutputsRunRequest(IList toolOutputs, IDictionary serializedAdditionalRawData) + internal SubmitToolOutputsRunRequest(IList toolOutputs, bool? stream, IDictionary serializedAdditionalRawData) { ToolOutputs = toolOutputs; + Stream = stream; _serializedAdditionalRawData = serializedAdditionalRawData; } @@ -68,5 +73,10 @@ internal SubmitToolOutputsRunRequest() /// A list of tools for which the outputs are being submitted. public IList ToolOutputs { get; } + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state. + /// + public bool? Stream { get; set; } } } diff --git a/.dotnet/src/Generated/OpenAIModelFactory.cs b/.dotnet/src/Generated/OpenAIModelFactory.cs index dcf7221bb..749081bca 100644 --- a/.dotnet/src/Generated/OpenAIModelFactory.cs +++ b/.dotnet/src/Generated/OpenAIModelFactory.cs @@ -416,6 +416,10 @@ public static CreateMessageRequest CreateMessageRequest(CreateMessageRequestRole /// The object type, which is always `thread.message`. /// The Unix timestamp (in seconds) for when the message was created. /// The [thread](/docs/api-reference/threads) ID that this message belongs to. + /// The status of the message, which can be either in_progress, incomplete, or completed. + /// On an incomplete message, details about why the message is incomplete. + /// The Unix timestamp at which the message was completed. + /// The Unix timestamp at which the message was marked as incomplete. /// The entity that produced the message. One of `user` or `assistant`. /// The content of the message in array of text and/or images. /// @@ -437,7 +441,7 @@ public static CreateMessageRequest CreateMessageRequest(CreateMessageRequestRole /// characters long and values can be a maxium of 512 characters long. /// /// A new instance for mocking. - public static MessageObject MessageObject(string id = null, MessageObjectObject @object = default, DateTimeOffset createdAt = default, string threadId = null, MessageObjectRole role = default, IEnumerable content = null, string assistantId = null, string runId = null, IEnumerable fileIds = null, IReadOnlyDictionary metadata = null) + public static MessageObject MessageObject(string id = null, MessageObjectObject @object = default, DateTimeOffset createdAt = default, string threadId = null, MessageObjectStatus status = default, MessageObjectIncompleteDetails incompleteDetails = null, DateTimeOffset? completedAt = null, DateTimeOffset? incompleteAt = null, MessageObjectRole role = default, IEnumerable content = null, string assistantId = null, string runId = null, IEnumerable fileIds = null, IReadOnlyDictionary metadata = null) { content ??= new List(); fileIds ??= new List(); @@ -448,6 +452,10 @@ public static MessageObject MessageObject(string id = null, MessageObjectObject @object, createdAt, threadId, + status, + incompleteDetails, + completedAt, + incompleteAt, role, content?.ToList(), assistantId, @@ -457,6 +465,14 @@ public static MessageObject MessageObject(string id = null, MessageObjectObject serializedAdditionalRawData: null); } + /// Initializes a new instance of . + /// The reason the message is incomplete. + /// A new instance for mocking. + public static MessageObjectIncompleteDetails MessageObjectIncompleteDetails(string reason = null) + { + return new MessageObjectIncompleteDetails(reason, serializedAdditionalRawData: null); + } + /// Initializes a new instance of . /// /// @@ -561,8 +577,12 @@ public static DeleteModelResponse DeleteModelResponse(string id = null, bool del /// additional information about the object in a structured format. Keys can be a maximum of 64 /// characters long and values can be a maxium of 512 characters long. /// + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// /// A new instance for mocking. - public static CreateThreadAndRunRequest CreateThreadAndRunRequest(string assistantId = null, CreateThreadRequest thread = null, string model = null, string instructions = null, IEnumerable tools = null, IDictionary metadata = null) + public static CreateThreadAndRunRequest CreateThreadAndRunRequest(string assistantId = null, CreateThreadRequest thread = null, string model = null, string instructions = null, IEnumerable tools = null, IDictionary metadata = null, bool? stream = null) { tools ??= new List(); metadata ??= new Dictionary(); @@ -574,6 +594,7 @@ public static CreateThreadAndRunRequest CreateThreadAndRunRequest(string assista instructions, tools?.ToList(), metadata, + stream, serializedAdditionalRawData: null); } @@ -727,8 +748,12 @@ public static RunCompletionUsage RunCompletionUsage(long completionTokens = defa /// additional information about the object in a structured format. Keys can be a maximum of 64 /// characters long and values can be a maxium of 512 characters long. /// + /// + /// If true, returns a stream of events that happen during the Run as server-sent events, + /// terminating when the Run enters a terminal state with a data: [DONE] message. + /// /// A new instance for mocking. - public static CreateRunRequest CreateRunRequest(string assistantId = null, string model = null, string instructions = null, string additionalInstructions = null, IEnumerable tools = null, IDictionary metadata = null) + public static CreateRunRequest CreateRunRequest(string assistantId = null, string model = null, string instructions = null, string additionalInstructions = null, IEnumerable tools = null, IDictionary metadata = null, bool? stream = null) { tools ??= new List(); metadata ??= new Dictionary(); @@ -740,6 +765,7 @@ public static CreateRunRequest CreateRunRequest(string assistantId = null, strin additionalInstructions, tools?.ToList(), metadata, + stream, serializedAdditionalRawData: null); } diff --git a/.dotnet/src/OpenAI.csproj b/.dotnet/src/OpenAI.csproj index cd5fa6d55..0fad9d5de 100644 --- a/.dotnet/src/OpenAI.csproj +++ b/.dotnet/src/OpenAI.csproj @@ -2,6 +2,7 @@ This is the OpenAI client library for developing .NET applications with rich experience. SDK Code Generation OpenAI + 1.0.0-beta.1 OpenAI netstandard2.0 latest diff --git a/.dotnet/src/Utility/ServerSentEvent.cs b/.dotnet/src/Utility/ServerSentEvent.cs new file mode 100644 index 000000000..ea91b1889 --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEvent.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal readonly struct ServerSentEvent +{ + // Gets the value of the SSE "event type" buffer, used to distinguish between event kinds. + public ReadOnlyMemory EventName { get; } + // Gets the value of the SSE "data" buffer, which holds the payload of the server-sent event. + public ReadOnlyMemory Data { get; } + // Gets the value of the "last event ID" buffer, with which a user agent can reestablish a session. + public ReadOnlyMemory LastEventId { get; } + // If present, gets the defined "retry" value for the event, which represents the delay before reconnecting. + public TimeSpan? ReconnectionTime { get; } + + private readonly IReadOnlyList _fields; + private readonly string _multiLineData; + + internal ServerSentEvent(IReadOnlyList fields) + { + _fields = fields; + StringBuilder multiLineDataBuilder = null; + for (int i = 0; i < _fields.Count; i++) + { + ReadOnlyMemory fieldValue = _fields[i].Value; + switch (_fields[i].FieldType) + { + case ServerSentEventFieldKind.Event: + EventName = fieldValue; + break; + case ServerSentEventFieldKind.Data: + { + if (multiLineDataBuilder != null) + { + multiLineDataBuilder.Append(fieldValue); + } + else if (Data.IsEmpty) + { + Data = fieldValue; + } + else + { + multiLineDataBuilder ??= new(); + multiLineDataBuilder.Append(fieldValue); + Data = null; + } + break; + } + case ServerSentEventFieldKind.Id: + LastEventId = fieldValue; + break; + case ServerSentEventFieldKind.Retry: + ReconnectionTime = Int32.TryParse(fieldValue.ToString(), out int retry) ? TimeSpan.FromMilliseconds(retry) : null; + break; + default: + break; + } + if (multiLineDataBuilder != null) + { + _multiLineData = multiLineDataBuilder.ToString(); + Data = _multiLineData.AsMemory(); + } + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEventField.cs b/.dotnet/src/Utility/ServerSentEventField.cs new file mode 100644 index 000000000..c6032c931 --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEventField.cs @@ -0,0 +1,64 @@ +using System; + +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal readonly struct ServerSentEventField +{ + public ServerSentEventFieldKind FieldType { get; } + + // TODO: we should not expose UTF16 publicly + public ReadOnlyMemory Value + { + get + { + if (_valueStartIndex >= _original.Length) + { + return ReadOnlyMemory.Empty; + } + else + { + return _original.AsMemory(_valueStartIndex); + } + } + } + + private readonly string _original; + private readonly int _valueStartIndex; + + internal ServerSentEventField(string line) + { + _original = line; + int colonIndex = _original.AsSpan().IndexOf(':'); + + ReadOnlyMemory fieldName = colonIndex < 0 ? _original.AsMemory(): _original.AsMemory(0, colonIndex); + FieldType = fieldName.Span switch + { + var x when x.SequenceEqual(s_eventFieldName.Span) => ServerSentEventFieldKind.Event, + var x when x.SequenceEqual(s_dataFieldName.Span) => ServerSentEventFieldKind.Data, + var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => ServerSentEventFieldKind.Id, + var x when x.SequenceEqual(s_retryFieldName.Span) => ServerSentEventFieldKind.Retry, + _ => ServerSentEventFieldKind.Ignored, + }; + + if (colonIndex < 0) + { + _valueStartIndex = _original.Length; + } + else if (colonIndex + 1 < _original.Length && _original[colonIndex + 1] == ' ') + { + _valueStartIndex = colonIndex + 2; + } + else + { + _valueStartIndex = colonIndex + 1; + } + } + + public override string ToString() => _original; + + private static readonly ReadOnlyMemory s_eventFieldName = "event".AsMemory(); + private static readonly ReadOnlyMemory s_dataFieldName = "data".AsMemory(); + private static readonly ReadOnlyMemory s_lastEventIdFieldName = "id".AsMemory(); + private static readonly ReadOnlyMemory s_retryFieldName = "retry".AsMemory(); +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEventFieldKind.cs b/.dotnet/src/Utility/ServerSentEventFieldKind.cs new file mode 100644 index 000000000..c3597b0ff --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEventFieldKind.cs @@ -0,0 +1,11 @@ +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal enum ServerSentEventFieldKind +{ + Event, + Data, + Id, + Retry, + Ignored +} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseAsyncEnumerator.cs b/.dotnet/src/Utility/SseAsyncEnumerator.cs index 743a1bedd..7588136b3 100644 --- a/.dotnet/src/Utility/SseAsyncEnumerator.cs +++ b/.dotnet/src/Utility/SseAsyncEnumerator.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; @@ -9,51 +10,37 @@ namespace OpenAI; internal static class SseAsyncEnumerator { - internal static async IAsyncEnumerable EnumerateFromSseStream( + private static ReadOnlyMemory[] _wellKnownTokens = + [ + "[DONE]".AsMemory(), + ]; + + internal static async IAsyncEnumerable EnumerateFromSseJsonStream( Stream stream, - Func> multiElementDeserializer, + Func, JsonElement, IEnumerable> multiElementDeserializer, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - try + using AsyncSseReader reader = new AsyncSseReader(stream); + + await foreach (ServerSentEvent sseEvent in reader.GetEventsAsync(cancellationToken)) { - using SseReader sseReader = new(stream); - while (!cancellationToken.IsCancellationRequested) + // TODO: does `continue` mean we keep reading from the Stream after + // the [DONE] event? If so, figure out if this is what we want here. + if (IsWellKnownDoneToken(sseEvent.Data)) continue; + + // TODO: Make faster with Utf8JsonReader, IModel? + using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Data); + + foreach (T item in multiElementDeserializer(sseEvent.EventName, sseDocument.RootElement)) { - SseLine? sseEvent = await sseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent is not null) - { - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - { - throw new InvalidDataException(); - } - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - break; - } - using JsonDocument sseMessageJson = JsonDocument.Parse(value); - IEnumerable newItems = multiElementDeserializer.Invoke(sseMessageJson.RootElement); - foreach (T item in newItems) - { - yield return item; - } - } + yield return item; } } - finally - { - // Always dispose the stream immediately once enumeration is complete for any reason - stream.Dispose(); - } } - internal static IAsyncEnumerable EnumerateFromSseStream( - Stream stream, - Func elementDeserializer, - CancellationToken cancellationToken = default) - => EnumerateFromSseStream( - stream, - (element) => new T[] { elementDeserializer.Invoke(element) }, - cancellationToken); + private static bool IsWellKnownDoneToken(ReadOnlyMemory data) + { + // TODO: Make faster than LINQ. + return _wellKnownTokens.Any(token => data.Span.SequenceEqual(token.Span)); + } } \ No newline at end of file diff --git a/.dotnet/src/Utility/SseLine.cs b/.dotnet/src/Utility/SseLine.cs deleted file mode 100644 index 4d82315f9..000000000 --- a/.dotnet/src/Utility/SseLine.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System; - -namespace OpenAI; - -// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream -internal readonly struct SseLine -{ - private readonly string _original; - private readonly int _colonIndex; - private readonly int _valueIndex; - - public static SseLine Empty { get; } = new SseLine(string.Empty, 0, false); - - internal SseLine(string original, int colonIndex, bool hasSpaceAfterColon) - { - _original = original; - _colonIndex = colonIndex; - _valueIndex = colonIndex + (hasSpaceAfterColon ? 2 : 1); - } - - public bool IsEmpty => _original.Length == 0; - public bool IsComment => !IsEmpty && _original[0] == ':'; - - // TODO: we should not expose UTF16 publicly - public ReadOnlyMemory FieldName => _original.AsMemory(0, _colonIndex); - public ReadOnlyMemory FieldValue => _original.AsMemory(_valueIndex); - - public override string ToString() => _original; -} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index cf0301408..cdefe3b3d 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -1,118 +1,207 @@ using System; -using System.ClientModel; -using System.ClientModel.Internal; +using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace OpenAI; -internal sealed class SseReader : IDisposable +#nullable enable + +// TODO: Can this type move its Dispose implementation into the +// Enumerator? +internal sealed class AsyncSseReader : IAsyncEnumerable, IDisposable, IAsyncDisposable +{ + private readonly Stream _stream; + + private bool _disposedValue; + + public AsyncSseReader(Stream stream) + { + _stream = stream; + } + + // TODO: Provide sync version + //// TODO: reuse code across sync and async. + + ///// + ///// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + ///// available and returning null once no further data is present on the stream. + ///// + ///// An optional cancellation token that can abort subsequent reads. + ///// + ///// The next in the stream, or null once no more data can be read from the stream. + ///// + //public IEnumerable GetEvents(CancellationToken cancellationToken = default) + //{ + // List fields = []; + + // while (!cancellationToken.IsCancellationRequested) + // { + // string line = _reader.ReadLine(); + // if (line == null) + // { + // // A null line indicates end of input + // yield break; + // } + // else if (line.Length == 0) + // { + // // An empty line should dispatch an event for pending accumulated fields + // ServerSentEvent nextEvent = new(fields); + // fields = []; + // yield return nextEvent; + // } + // else if (line[0] == ':') + // { + // // A line beginning with a colon is a comment and should be ignored + // continue; + // } + // else + // { + // // Otherwise, process the the field + value and accumulate it for the next dispatched event + // fields.Add(new ServerSentEventField(line)); + // } + // } + + // yield break; + //} + + ///// + ///// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + ///// available and returning null once no further data is present on the stream. + ///// + ///// An optional cancellation token that can abort subsequent reads. + ///// + ///// The next in the stream, or null once no more data can be read from the stream. + ///// + //public async IAsyncEnumerable GetEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + //{ + // List fields = []; + + // while (!cancellationToken.IsCancellationRequested) + // { + // string line = await _reader.ReadLineAsync().ConfigureAwait(false); + // if (line == null) + // { + // // A null line indicates end of input + // yield break; + // } + // else if (line.Length == 0) + // { + // // An empty line should dispatch an event for pending accumulated fields + // ServerSentEvent nextEvent = new(fields); + // fields = []; + // yield return nextEvent; + // } + // else if (line[0] == ':') + // { + // // A line beginning with a colon is a comment and should be ignored + // continue; + // } + // else + // { + // // Otherwise, process the the field + value and accumulate it for the next dispatched event + // fields.Add(new ServerSentEventField(line)); + // } + // } + + // yield break; + //} + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new AsyncSseReaderEnumerator(_stream); + } + + private class AsyncSseReaderEnumerator : IAsyncEnumerator { private readonly Stream _stream; private readonly StreamReader _reader; - private bool _disposedValue; - public SseReader(Stream stream) + private ServerSentEvent? _current; + + public AsyncSseReaderEnumerator(Stream stream) { _stream = stream; _reader = new StreamReader(stream); } - public SseLine? TryReadSingleFieldEvent() - { - while (true) - { - SseLine? line = TryReadLine(); - if (line == null) - return null; - if (line.Value.IsEmpty) - throw new InvalidDataException("event expected."); - SseLine? empty = TryReadLine(); - if (empty != null && !empty.Value.IsEmpty) - throw new NotSupportedException("Multi-filed events not supported."); - if (!line.Value.IsComment) - return line; // skip comment lines - } - } + // TODO: recall proper semantics for Current and fix null issues. + public ServerSentEvent Current => _current!.Value; - // TODO: we should support cancellation tokens, but StreamReader does not in NS2 - public async Task TryReadSingleFieldEventAsync() + public async ValueTask MoveNextAsync() { + // TODO: don't reallocate every call to MoveNext + // TODO: UTF-8 all the way down possible here? + List fields = []; + + // TODO: How to handle the CancellationToken? + + // TODO: Call different ConfigureAwait variant in this context, or no? while (true) { - SseLine? line = await TryReadLineAsync().ConfigureAwait(false); + string line = await _reader.ReadLineAsync().ConfigureAwait(false); + if (line == null) - return null; - if (line.Value.IsEmpty) - throw new InvalidDataException("event expected."); - SseLine? empty = await TryReadLineAsync().ConfigureAwait(false); - if (empty != null && !empty.Value.IsEmpty) - throw new NotSupportedException("Multi-filed events not supported."); - if (!line.Value.IsComment) - return line; // skip comment lines + { + // A null line indicates end of input + return false; + } + + // TODO: another way to rework this for perf, clarity, or is + // this optimal? + else if (line.Length == 0) + { + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + fields = []; + _current = nextEvent; + return true; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); + } } } - public SseLine? TryReadLine() + public ValueTask DisposeAsync() { - string lineText = _reader.ReadLine(); - if (lineText == null) - return null; - if (lineText.Length == 0) - return SseLine.Empty; - if (TryParseLine(lineText, out SseLine line)) - return line; - return null; + // TODO: revisit per platforms where async dispose is available. + _stream?.Dispose(); + _reader?.Dispose(); + return new ValueTask(); } + } - // TODO: we should support cancellation tokens, but StreamReader does not in NS2 - public async Task TryReadLineAsync() - { - string lineText = await _reader.ReadLineAsync().ConfigureAwait(false); - if (lineText == null) - return null; - if (lineText.Length == 0) - return SseLine.Empty; - if (TryParseLine(lineText, out SseLine line)) - return line; - return null; - } + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } - private static bool TryParseLine(string lineText, out SseLine line) + private void Dispose(bool disposing) + { + if (!_disposedValue) { - if (lineText.Length == 0) + if (disposing) { - line = default; - return false; + _stream.Dispose(); } - ReadOnlySpan lineSpan = lineText.AsSpan(); - int colonIndex = lineSpan.IndexOf(':'); - ReadOnlySpan fieldValue = lineSpan.Slice(colonIndex + 1); - - bool hasSpace = false; - if (fieldValue.Length > 0 && fieldValue[0] == ' ') - hasSpace = true; - line = new SseLine(lineText, colonIndex, hasSpace); - return true; + _disposedValue = true; } + } - private void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _reader.Dispose(); - _stream.Dispose(); - } - - _disposedValue = true; - } - } - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - } \ No newline at end of file + ValueTask IAsyncDisposable.DisposeAsync() + { + // TODO: revisit per platforms where async dispose is available. + return new ValueTask(); + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/StreamedEventCollection.cs b/.dotnet/src/Utility/StreamedEventCollection.cs new file mode 100644 index 000000000..d9584fa1d --- /dev/null +++ b/.dotnet/src/Utility/StreamedEventCollection.cs @@ -0,0 +1,33 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace OpenAI; + +// TODO: Make it work for non-JSON models too. +// TODO: The list part is a hack - pull it out +internal abstract class StreamedEventCollection : List, IJsonModel> + // This just promises me I can deserialize a collection of serializable models. + // Should it be more general? + where TModel : IJsonModel +{ + public abstract StreamedEventCollection Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options); + + public abstract StreamedEventCollection Create(BinaryData data, ModelReaderWriterOptions options); + + public string GetFormatFromOptions(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public BinaryData Write(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } +} diff --git a/.dotnet/src/Utility/StreamingClientResultOfT.cs b/.dotnet/src/Utility/StreamingClientResultOfT.cs new file mode 100644 index 000000000..7483dad71 --- /dev/null +++ b/.dotnet/src/Utility/StreamingClientResultOfT.cs @@ -0,0 +1,31 @@ +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; + +namespace OpenAI; + +#pragma warning disable CS1591 // public XML comments + +/// +/// Represents an operation response with streaming content that can be deserialized and enumerated while the response +/// is still being received. +/// +/// The data type representative of distinct, streamable items. +// TODO: Revisit the IDisposable question +public abstract class StreamingClientResult : ClientResult, IAsyncEnumerable + // TODO: Note that constraining the T means the implementation can use + // ModelReaderWriter for deserialization. + where T : IPersistableModel +{ + protected StreamingClientResult(PipelineResponse response) : base(response) + { + } + + // Note that if the implementation disposes the stream, the caller can only + // enumerate the results once. I think this makes sense, but we should + // make sure architects agree. + public abstract IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default); +} + +#pragma warning restore CS1591 // public XML comments \ No newline at end of file diff --git a/.dotnet/src/Utility/StreamingResult.cs b/.dotnet/src/Utility/StreamingResult.cs deleted file mode 100644 index a1b6ff538..000000000 --- a/.dotnet/src/Utility/StreamingResult.cs +++ /dev/null @@ -1,95 +0,0 @@ -using System.ClientModel; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Threading; -using System.Collections.Generic; -using System; - -namespace OpenAI; - -/// -/// Represents an operation response with streaming content that can be deserialized and enumerated while the response -/// is still being received. -/// -/// The data type representative of distinct, streamable items. -public class StreamingClientResult - : IDisposable - , IAsyncEnumerable -{ - private ClientResult _rawResult { get; } - private IAsyncEnumerable _asyncEnumerableSource { get; } - private bool _disposedValue { get; set; } - - private StreamingClientResult() { } - - private StreamingClientResult( - ClientResult rawResult, - Func> asyncEnumerableProcessor) - { - _rawResult = rawResult; - _asyncEnumerableSource = asyncEnumerableProcessor.Invoke(rawResult); - } - - /// - /// Creates a new instance of using the provided underlying HTTP response. The - /// provided function will be used to resolve the response into an asynchronous enumeration of streamed response - /// items. - /// - /// The HTTP response. - /// - /// The function that will resolve the provided response into an IAsyncEnumerable. - /// - /// - /// A new instance of that will be capable of asynchronous enumeration of - /// items from the HTTP response. - /// - internal static StreamingClientResult CreateFromResponse( - ClientResult result, - Func> asyncEnumerableProcessor) - { - return new(result, asyncEnumerableProcessor); - } - - /// - /// Gets the underlying instance that this may enumerate - /// over. - /// - /// The instance attached to this . - public PipelineResponse GetRawResponse() => _rawResult.GetRawResponse(); - - /// - /// Gets the asynchronously enumerable collection of distinct, streamable items in the response. - /// - /// - /// The return value of this method may be used with the "await foreach" statement. - /// - /// As explicitly implements , callers may - /// enumerate a instance directly instead of calling this method. - /// - /// - /// - public IAsyncEnumerable EnumerateValues() => this; - - /// - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - /// - protected virtual void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _rawResult?.GetRawResponse()?.Dispose(); - } - _disposedValue = true; - } - } - - IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) - => _asyncEnumerableSource.GetAsyncEnumerator(cancellationToken); -} \ No newline at end of file diff --git a/.dotnet/tests/OpenAI.Tests.csproj b/.dotnet/tests/OpenAI.Tests.csproj index 59b95785f..cf68d6d01 100644 --- a/.dotnet/tests/OpenAI.Tests.csproj +++ b/.dotnet/tests/OpenAI.Tests.csproj @@ -1,6 +1,6 @@ - net8.0 + net7.0 $(NoWarn);CS1591 latest @@ -14,6 +14,6 @@ - + \ No newline at end of file diff --git a/.dotnet/tests/Samples/Chat/Sample02_StreamingChatAsync.cs b/.dotnet/tests/Samples/Chat/Sample02_StreamingChatAsync.cs index 6ce1a96f0..e59ff628e 100644 --- a/.dotnet/tests/Samples/Chat/Sample02_StreamingChatAsync.cs +++ b/.dotnet/tests/Samples/Chat/Sample02_StreamingChatAsync.cs @@ -1,6 +1,8 @@ using NUnit.Framework; using OpenAI.Chat; using System; +using System.ClientModel.Primitives; +using System.IO; using System.Threading.Tasks; namespace OpenAI.Samples @@ -21,6 +23,9 @@ public async Task Sample02_StreamingChatAsync() { Console.Write(chatUpdate.ContentUpdate); } + + PipelineResponse response = result.GetRawResponse(); + Stream stream = response.ContentStream; } } } diff --git a/.dotnet/tests/TestScenarios/AssistantTests.cs b/.dotnet/tests/TestScenarios/AssistantTests.cs index 005bf1234..aa19fc2d8 100644 --- a/.dotnet/tests/TestScenarios/AssistantTests.cs +++ b/.dotnet/tests/TestScenarios/AssistantTests.cs @@ -1,7 +1,9 @@ using NUnit.Framework; using OpenAI.Assistants; +using OpenAI.Chat; using System; using System.ClientModel; +using System.Collections.Generic; using System.Threading.Tasks; using static OpenAI.Tests.TestHelpers; @@ -127,6 +129,104 @@ public async Task BasicFunctionToolWorks() Assert.That(runResult.Value.Status, Is.Not.EqualTo(RunStatus.RequiresAction)); } + [Test] + public async Task SimpleStreamingRunWorks() + { + AssistantClient client = GetTestClient(); + Assistant assistant = await CreateCommonTestAssistantAsync(); + + StreamingClientResult runUpdateResult = client.CreateThreadAndRunStreaming( + assistant.Id, + new ThreadCreationOptions() + { + Messages = + { + "Hello, assistant! Can you help me?", + } + }); + Assert.That(runUpdateResult, Is.Not.Null); + await foreach (StreamingUpdate runUpdate in runUpdateResult) + { + if (runUpdate is StreamingMessageCreation messageCreation) + { + Console.WriteLine($"Message created, id={messageCreation.Message.Id}"); + } + if (runUpdate is StreamingMessageUpdate messageUpdate) + { + Console.Write(messageUpdate.ContentUpdate.GetText()); + } + if (runUpdate is StreamingMessageCompletion messageCompletion) + { + Console.WriteLine(); + Console.WriteLine($"Message complete: {messageCompletion.Message.ContentItems[0].GetText()}"); + } + } + } + + [Test] + public async Task StreamingWithToolsWorks() + { + AssistantClient client = GetTestClient(); + ClientResult assistantResult = await client.CreateAssistantAsync("gpt-3.5-turbo", new AssistantCreationOptions() + { + Instructions = "You are a helpful math assistant that helps with visualizing equations. Use the code interpreter tool when asked to generate images. Use provided functions to resolve appropriate unknown values", + Tools = + { + new CodeInterpreterToolDefinition(), + new FunctionToolDefinition("get_boilerplate_equation", "Retrieves a predefined 'boilerplate equation' from the caller."), + }, + Metadata = { [s_cleanupMetadataKey] = "true" }, + }); + Assistant assistant = assistantResult.Value; + Assert.That(assistant, Is.Not.Null); + + ClientResult threadResult = await client.CreateThreadAsync(new ThreadCreationOptions() + { + Messages = + { + "Please make a graph for my boilerplate equation", + }, + }); + AssistantThread thread = threadResult.Value; + Assert.That(thread, Is.Not.Null); + + StreamingClientResult streamingResult = await client.CreateRunStreamingAsync(thread.Id, assistant.Id); + Assert.That(streamingResult, Is.Not.Null); + List requiredActions = []; + ThreadRun initialStreamedRun = null; + await foreach (StreamingUpdate streamingUpdate in streamingResult) + { + if (streamingUpdate is StreamingRunCreation streamingRunCreation) + { + initialStreamedRun = streamingRunCreation.Run; + } + if (streamingUpdate is StreamingRequiredAction streamedRequiredAction) + { + requiredActions.Add(streamedRequiredAction.RequiredAction); + } + Console.WriteLine(streamingUpdate.GetRawSseEvent().ToString()); + } + Assert.That(initialStreamedRun?.Id, Is.Not.Null.Or.Empty); + Assert.That(requiredActions, Is.Not.Empty); + + List toolOutputs = []; + foreach (RunRequiredAction requiredAction in requiredActions) + { + if (requiredAction is RequiredFunctionToolCall functionCall) + { + if (functionCall.Name == "get_boilerplate_equation") + { + toolOutputs.Add(new(functionCall, "y = 14x - 3")); + } + } + } + streamingResult = await client.SubmitToolOutputsStreamingAsync(thread.Id, initialStreamedRun.Id, toolOutputs); + await foreach (StreamingUpdate streamingUpdate in streamingResult) + { + Console.WriteLine(streamingUpdate.GetRawSseEvent().ToString()); + } + } + private async Task CreateCommonTestAssistantAsync() { AssistantClient client = new(); @@ -142,7 +242,8 @@ private async Task CreateCommonTestAssistantAsync() return newAssistantResult.Value; } - private async Task DeleteRecentTestThings() + [TearDown] + protected async Task DeleteRecentTestThings() { AssistantClient client = GetTestClient(); foreach(Assistant assistant in client.GetAssistants().Value) diff --git a/messages/models.tsp b/messages/models.tsp index 67d78db29..16c059d4b 100644 --- a/messages/models.tsp +++ b/messages/models.tsp @@ -72,6 +72,33 @@ model MessageObject { /** The [thread](/docs/api-reference/threads) ID that this message belongs to. */ thread_id: string; + /** + * The status of the message, which can be either in_progress, incomplete, or completed. + */ + status: "in_progress" | "incomplete" | "completed"; + + /** + * On an incomplete message, details about why the message is incomplete. + */ + incomplete_details: { + /** + * The reason the message is incomplete. + */ + reason: string; + } | null; + + /** + * The Unix timestamp at which the message was completed. + */ + @encode("unixTimestamp", int32) + completed_at: utcDateTime | null; + + /** + * The Unix timestamp at which the message was marked as incomplete. + */ + @encode("unixTimestamp", int32) + incomplete_at: utcDateTime | null; + /** The entity that produced the message. One of `user` or `assistant`. */ role: "user" | "assistant"; diff --git a/runs/models.tsp b/runs/models.tsp index 011ec94a9..eff9b0e15 100644 --- a/runs/models.tsp +++ b/runs/models.tsp @@ -42,6 +42,12 @@ model CreateRunRequest { */ @extension("x-oaiTypeLabel", "map") metadata?: Record | null; + + /** + * If true, returns a stream of events that happen during the Run as server-sent events, + * terminating when the Run enters a terminal state with a data: [DONE] message. + */ + stream?: boolean; } model ModifyRunRequest { @@ -87,6 +93,12 @@ model CreateThreadAndRunRequest { */ @extension("x-oaiTypeLabel", "map") metadata?: Record | null; + + /** + * If true, returns a stream of events that happen during the Run as server-sent events, + * terminating when the Run enters a terminal state with a data: [DONE] message. + */ + stream?: boolean; } model SubmitToolOutputsRunRequest { @@ -100,6 +112,12 @@ model SubmitToolOutputsRunRequest { /** The output of the tool call to be submitted to continue the run. */ output?: string; }[]; + + /** + * If true, returns a stream of events that happen during the Run as server-sent events, + * terminating when the Run enters a terminal state. + */ + stream?: boolean; } model ListRunsResponse { diff --git a/tsp-output/@typespec/openapi3/openapi.yaml b/tsp-output/@typespec/openapi3/openapi.yaml index 224f0b7a0..83319a41b 100644 --- a/tsp-output/@typespec/openapi3/openapi.yaml +++ b/tsp-output/@typespec/openapi3/openapi.yaml @@ -3638,6 +3638,11 @@ components: additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. x-oaiTypeLabel: map + stream: + type: boolean + description: |- + If true, returns a stream of events that happen during the Run as server-sent events, + terminating when the Run enters a terminal state with a data: [DONE] message. CreateRunRequestTool: oneOf: - $ref: '#/components/schemas/AssistantToolsCode' @@ -3743,6 +3748,11 @@ components: additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. x-oaiTypeLabel: map + stream: + type: boolean + description: |- + If true, returns a stream of events that happen during the Run as server-sent events, + terminating when the Run enters a terminal state with a data: [DONE] message. CreateThreadRequest: type: object properties: @@ -4696,6 +4706,10 @@ components: - object - created_at - thread_id + - status + - incomplete_details + - completed_at + - incomplete_at - role - content - assistant_id @@ -4718,6 +4732,33 @@ components: thread_id: type: string description: The [thread](/docs/api-reference/threads) ID that this message belongs to. + status: + type: string + enum: + - in_progress + - incomplete + - completed + description: The status of the message, which can be either in_progress, incomplete, or completed. + incomplete_details: + type: object + properties: + reason: + type: string + description: The reason the message is incomplete. + required: + - reason + nullable: true + description: On an incomplete message, details about why the message is incomplete. + completed_at: + type: string + format: date-time + nullable: true + description: The Unix timestamp at which the message was completed. + incomplete_at: + type: string + format: date-time + nullable: true + description: The Unix timestamp at which the message was marked as incomplete. role: type: string enum: @@ -5512,6 +5553,11 @@ components: type: string description: The output of the tool call to be submitted to continue the run. description: A list of tools for which the outputs are being submitted. + stream: + type: boolean + description: |- + If true, returns a stream of events that happen during the Run as server-sent events, + terminating when the Run enters a terminal state. SuffixString: type: string minLength: 1