chore: [openai] (PoC) Generated model classes #266
base: main
Changes from all commits
c7da29d
f7fd3f8
1072172
8554091
483db01
bbc0146
433d84a
4bc431b
d8dd4c0
0620966
cc7e833
753cbc4
8df9ef8
c5d45b8
7db4500
4ebc2a3
JacksonMixins.java (new file)
@@ -0,0 +1,31 @@
package com.sap.ai.sdk.foundationmodels.openai;

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionResponse;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionStreamResponse;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
final class JacksonMixins {

  @JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
  interface CreateChatCompletionStreamResponseMixIn {}

  @JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
  interface CreateChatCompletionResponseMixIn {}

  @JsonTypeInfo(
      use = JsonTypeInfo.Id.NAME,
      property = "object",
      defaultImpl = CreateChatCompletionResponse.class,
      visible = true)
  @JsonSubTypes({
    @JsonSubTypes.Type(value = CreateChatCompletionResponse.class, name = "chat.completion"),
    @JsonSubTypes.Type(
        value = CreateChatCompletionStreamResponse.class,
        name = "chat.completion.chunk"),
  })
  public interface ChatCompletionCreate200ResponseMixIn {}
}
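For orientation, these mix-ins are meant to be registered on the Jackson ObjectMapper (as done in OpenAiClient further below), so that the "object" property of the payload selects the concrete response type. The following is a minimal sketch of that behaviour, assuming the generated model2 classes accept a payload trimmed down to the discriminator; exception handling for the checked Jackson exceptions is omitted.

// Sketch only: register the mix-ins and let the "object" discriminator pick the subtype.
// The payload is reduced for illustration; real responses carry many more fields.
final ObjectMapper mapper =
    new ObjectMapper()
        .addMixIn(
            ChatCompletionsCreate200Response.class,
            JacksonMixins.ChatCompletionCreate200ResponseMixIn.class)
        .addMixIn(
            CreateChatCompletionResponse.class,
            JacksonMixins.CreateChatCompletionResponseMixIn.class)
        .addMixIn(
            CreateChatCompletionStreamResponse.class,
            JacksonMixins.CreateChatCompletionStreamResponseMixIn.class);

final ChatCompletionsCreate200Response parsed =
    mapper.readValue(
        "{\"object\":\"chat.completion.chunk\"}", ChatCompletionsCreate200Response.class);
// parsed is expected to be a CreateChatCompletionStreamResponse here; "chat.completion"
// (or any other value, via defaultImpl) would yield a CreateChatCompletionResponse.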
OpenAiChatCompletionDelta.java (new file)
@@ -0,0 +1,91 @@
package com.sap.ai.sdk.foundationmodels.openai;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.Beta;
import com.sap.ai.sdk.core.common.StreamedDelta;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionsCreate200Response;
import com.sap.ai.sdk.foundationmodels.openai.model2.CompletionUsage;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionResponse;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionStreamResponse;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.RequiredArgsConstructor;
import lombok.Value;

/**
 * OpenAI chat completion output delta for streaming.
 *
 * @since 1.3.0
 */
@Beta
@Value
@RequiredArgsConstructor(onConstructor_ = @JsonCreator)
public class OpenAiChatCompletionDelta implements StreamedDelta {
  ChatCompletionsCreate200Response originalResponse;

  @Nonnull
  @Override
  public String getDeltaContent() {
    if (getOriginalResponse() instanceof CreateChatCompletionStreamResponse response) {
      final var choices = response.getChoices();
      if (!choices.isEmpty() && choices.get(0).getIndex() == 0) {
        final var message = choices.get(0).getDelta();
        if (message != null) {
          return Objects.requireNonNullElse(message.getContent(), "");
        }
      }
    }
    if (getOriginalResponse() instanceof CreateChatCompletionResponse response) {
      final var choices = response.getChoices();
      if (!choices.isEmpty() && choices.get(0).getIndex() == 0) {
        final var delta = choices.get(0).getCustomField("delta"); // .getMessage() does not work
        if (delta instanceof String message) {
          return message;
        }
      }
    }
    return "";
  }

  @Nullable
  @Override
  public String getFinishReason() {
    if (getOriginalResponse() instanceof CreateChatCompletionStreamResponse response) {
      final var choices = response.getChoices();
      if (!choices.isEmpty()) {
        final var finishReason = choices.get(0).getFinishReason();
        return finishReason != null ? finishReason.getValue() : null;
      }
    }
    if (getOriginalResponse() instanceof CreateChatCompletionResponse response) {
      final var choices = response.getChoices();
      if (!choices.isEmpty()) {
        final var finishReason = choices.get(0).getFinishReason();
        return finishReason != null ? finishReason.getValue() : null;
      }
    }
    return null;
  }

  /**
   * Get the completion usage from the response, or null if it is not available.
   *
   * @param objectMapper The object mapper to use for conversion.
   * @return The completion usage or null.
   */
  @Nullable
  public CompletionUsage getCompletionUsage(@Nonnull final ObjectMapper objectMapper) {
    if (getOriginalResponse() instanceof CreateChatCompletionStreamResponse response
        && response.getCustomFieldNames().contains("usage")
        && response.getCustomField("usage") instanceof Map<?, ?> usage) {
      return objectMapper.convertValue(usage, CompletionUsage.class);
    }
    if (getOriginalResponse() instanceof CreateChatCompletionResponse response) {
      return response.getUsage();
    }
    return null;
  }
}
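A consumption sketch for these deltas follows, assuming an initialized OpenAiClient named client and a populated CreateChatCompletionRequest named request (both placeholders). With includeUsage enabled, the usage data typically arrives only on the final chunk.

// Sketch: aggregate streamed content and pick up the usage from the final delta.
final var mapper = new ObjectMapper();
final var content = new StringBuilder();
CompletionUsage usage = null;
try (var deltas = client.streamChatCompletionDeltas(request)) {
  for (final OpenAiChatCompletionDelta delta :
      (Iterable<OpenAiChatCompletionDelta>) deltas::iterator) {
    content.append(delta.getDeltaContent());
    if (delta.getCompletionUsage(mapper) != null) {
      usage = delta.getCompletionUsage(mapper); // usually only present on the last chunk
    }
  }
}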
OpenAiClient.java
@@ -10,14 +10,18 @@
import com.sap.ai.sdk.core.common.ClientResponseHandler;
import com.sap.ai.sdk.core.common.ClientStreamingHandler;
import com.sap.ai.sdk.core.common.StreamedDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatSystemMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiError;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionRequestSystemMessage;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionRequestSystemMessageContent;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionRequestUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionRequestUserMessageContent;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionStreamOptions;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionsCreate200Response;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionRequest;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionResponse;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionStreamResponse;
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreate200Response;
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequest;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
@@ -126,13 +130,24 @@ public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) {
   * @throws OpenAiClientException if the request fails
   */
  @Nonnull
  public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt)
  public CreateChatCompletionResponse chatCompletion(@Nonnull final String prompt)
      throws OpenAiClientException {
    final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
    final CreateChatCompletionRequest parameters = new CreateChatCompletionRequest();

    if (systemPrompt != null) {
      parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
      parameters.addMessagesItem(
          new ChatCompletionRequestSystemMessage()
              .role(ChatCompletionRequestSystemMessage.RoleEnum.SYSTEM)
              .content(ChatCompletionRequestSystemMessageContent.create(systemPrompt)));
    }
    parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
    parameters
        .addMessagesItem(
            new ChatCompletionRequestUserMessage()
                .role(ChatCompletionRequestUserMessage.RoleEnum.USER)
                .content(ChatCompletionRequestUserMessageContent.create(prompt)))
        .functions(null)
        .tools(null)
        .parallelToolCalls(null);
    return chatCompletion(parameters);
  }

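For comparison with the old convenience types, calling the new signature directly looks roughly like the sketch below. The client is a placeholder, and the accessors used to read the answer (getMessage() and getContent() on the first choice) are assumed from the generated model rather than shown in this diff.

// Sketch: the response accessors are presumed generator output, not confirmed by this diff.
final CreateChatCompletionResponse response = client.chatCompletion("Hello, world!");
final String answer = response.getChoices().get(0).getMessage().getContent();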
@@ -144,16 +159,16 @@ public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt)
   * @throws OpenAiClientException if the request fails
   */
  @Nonnull
  public OpenAiChatCompletionOutput chatCompletion(
      @Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
  public CreateChatCompletionResponse chatCompletion(
      @Nonnull final CreateChatCompletionRequest parameters) throws OpenAiClientException {
    warnIfUnsupportedUsage();
    return execute("/chat/completions", parameters, OpenAiChatCompletionOutput.class);
    return execute("/chat/completions", parameters, CreateChatCompletionResponse.class);
  }

  /**
   * Stream a completion for the given prompt. Returns a <b>lazily</b> populated stream of text
   * chunks. To access more details about the individual chunks, use {@link
   * #streamChatCompletionDeltas(OpenAiChatCompletionParameters)}.
   * #streamChatCompletionDeltas(CreateChatCompletionRequest)}.
   *
   * <p>The stream should be consumed using a try-with-resources block to ensure that the underlying
   * HTTP connection is closed.

@@ -173,22 +188,33 @@ public OpenAiChatCompletionOutput chatCompletion(
   * @param prompt a text message.
   * @return A stream of message deltas
   * @throws OpenAiClientException if the request fails or if the finish reason is content_filter
   * @see #streamChatCompletionDeltas(OpenAiChatCompletionParameters)
   * @see #streamChatCompletionDeltas(CreateChatCompletionRequest)
   */
  @Nonnull
  public Stream<String> streamChatCompletion(@Nonnull final String prompt)
      throws OpenAiClientException {
    final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
    final CreateChatCompletionRequest parameters = new CreateChatCompletionRequest();

    if (systemPrompt != null) {
      parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
      parameters.addMessagesItem(
          new ChatCompletionRequestSystemMessage()
              .role(ChatCompletionRequestSystemMessage.RoleEnum.SYSTEM)
              .content(ChatCompletionRequestSystemMessageContent.create(systemPrompt)));
    }
    parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
    final var userMessage =
        new ChatCompletionRequestUserMessage()
            .role(ChatCompletionRequestUserMessage.RoleEnum.USER)
            .content(ChatCompletionRequestUserMessageContent.create(prompt));
    parameters.addMessagesItem(userMessage).tools(null).functions(null).parallelToolCalls(null);

    return streamChatCompletionDeltas(parameters)
        .map(OpenAiChatCompletionDelta.class::cast)
        .peek(OpenAiClient::throwOnContentFilter)
        .map(OpenAiChatCompletionDelta::getDeltaContent);
  }

  private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelta delta) {
  private static void throwOnContentFilter(
      @Nonnull final com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta delta) {
    final String finishReason = delta.getFinishReason();
    if (finishReason != null && finishReason.equals("content_filter")) {
      throw new OpenAiClientException("Content filter filtered the output.");

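As the Javadoc above notes, the returned stream is lazily populated and should be consumed in a try-with-resources block; a minimal consumption sketch, assuming an initialized OpenAiClient named client:

// Sketch: closing the stream also closes the underlying HTTP connection.
try (Stream<String> stream = client.streamChatCompletion("Tell me a story")) {
  stream.forEach(System.out::print);
}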
@@ -224,9 +250,9 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt
   */
  @Nonnull
  public Stream<OpenAiChatCompletionDelta> streamChatCompletionDeltas(
      @Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
      @Nonnull final CreateChatCompletionRequest parameters) throws OpenAiClientException {
    warnIfUnsupportedUsage();
    parameters.enableStreaming();
    parameters.stream(true).streamOptions(new ChatCompletionStreamOptions().includeUsage(true));
    return executeStream("/chat/completions", parameters, OpenAiChatCompletionDelta.class);
  }

@@ -246,9 +272,9 @@ private void warnIfUnsupportedUsage() {
   * @throws OpenAiClientException if the request fails
   */
  @Nonnull
  public OpenAiEmbeddingOutput embedding(@Nonnull final OpenAiEmbeddingParameters parameters)
  public EmbeddingsCreate200Response embedding(@Nonnull final EmbeddingsCreateRequest parameters)
      throws OpenAiClientException {
    return execute("/embeddings", parameters, OpenAiEmbeddingOutput.class);
    return execute("/embeddings", parameters, EmbeddingsCreate200Response.class);
  }

  @Nonnull

@@ -300,6 +326,17 @@ private <D extends StreamedDelta> Stream<D> streamRequest(
    try {
      final var client = ApacheHttpClient5Accessor.getHttpClient(destination);
      return new ClientStreamingHandler<>(deltaType, OpenAiError.class, OpenAiClientException::new)
          .objectMapper(
              JACKSON
                  .addMixIn(
                      CreateChatCompletionResponse.class,
                      JacksonMixins.CreateChatCompletionResponseMixIn.class)
                  .addMixIn(
                      CreateChatCompletionStreamResponse.class,
                      JacksonMixins.CreateChatCompletionStreamResponseMixIn.class)
                  .addMixIn(
                      ChatCompletionsCreate200Response.class,
                      JacksonMixins.ChatCompletionCreate200ResponseMixIn.class))
          .handleStreamingResponse(client.executeOpen(null, request, null));
Comment on lines 328 to 340:

(Comment) Workaround because Jackson is unable to determine the deserializer by itself for …

(Question) Maybe I missed the problem, but I don't see why deduction would not work, as there is an additional property. According to my understanding, as long as there is a distinguishing property, Jackson should be able to resolve the subtype.

(Reply) The actual payload for streamed responses differs significantly from the specification; it is inconsistent. Whatever annotation / setting we choose, it will be a compromise. If you find deduction working, or if you prefer a different mixin, please be my guest and change it as part of the upcoming convenience API draft.
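Purely to illustrate the question above, a deduction-based alternative to the name-based mixin could look like the sketch below. It is not part of this change, and whether it survives the inconsistent streamed payloads is exactly the open point.

// Hypothetical alternative: let Jackson deduce the subtype from the fields present
// in the payload instead of reading the "object" discriminator. Not part of this PR.
@JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION)
@JsonSubTypes({
  @JsonSubTypes.Type(CreateChatCompletionResponse.class),
  @JsonSubTypes.Type(CreateChatCompletionStreamResponse.class)
})
interface ChatCompletionCreate200ResponseDeductionMixIn {}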
    } catch (final IOException e) {
      throw new OpenAiClientException("Request to OpenAI model failed", e);
(Comment) This is the unfortunate diff expected for regular usage, without the convenience API.