chore: [openai] (PoC) Generated model classes #266

Open · wants to merge 16 commits into base: main
3 changes: 3 additions & 0 deletions .pipeline/checkstyle-suppressions.xml
@@ -8,7 +8,10 @@
<!-- Suppress generated clients -->
<suppress files="[/\\]core[/\\]client[/\\]" checks=".*"/>
<suppress files="[/\\]core[/\\]model[/\\]" checks=".*"/>
<suppress files="[/\\]openai[/\\]model2[/\\]" checks=".*"/>
<suppress files="[/\\]orchestration[/\\]model[/\\]" checks=".*"/>
<!-- Suppress TODOs -->
<suppress files="OpenAiChatMessage.java" checks="TodoComment" lines="257,7" />
<suppress files="ChatCompletionResponseMessage.java" checks="TodoComment" lines="53,34" />
<suppress files="CreateChatCompletionRequest.java" checks="TodoComment" lines="73,47" />
</suppressions>
12 changes: 6 additions & 6 deletions foundation-models/openai/pom.xml
@@ -33,12 +33,12 @@
</developers>
<properties>
<project.rootdir>${project.basedir}/../../</project.rootdir>
<coverage.complexity>71%</coverage.complexity>
<coverage.line>80%</coverage.line>
<coverage.instruction>76%</coverage.instruction>
<coverage.branch>69%</coverage.branch>
<coverage.method>83%</coverage.method>
<coverage.class>84%</coverage.class>
<coverage.complexity>32%</coverage.complexity>
<coverage.line>42%</coverage.line>
<coverage.instruction>42%</coverage.instruction>
<coverage.branch>16%</coverage.branch>
<coverage.method>48%</coverage.method>
<coverage.class>40%</coverage.class>
</properties>
<dependencies>
<dependency>
@@ -0,0 +1,31 @@
package com.sap.ai.sdk.foundationmodels.openai;

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionResponse;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionStreamResponse;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
final class JacksonMixins {

@JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
interface CreateChatCompletionStreamResponseMixIn {}

@JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
interface CreateChatCompletionResponseMixIn {}

@JsonTypeInfo(
use = JsonTypeInfo.Id.NAME,
property = "object",
defaultImpl = CreateChatCompletionResponse.class,
visible = true)
@JsonSubTypes({
@JsonSubTypes.Type(value = CreateChatCompletionResponse.class, name = "chat.completion"),
@JsonSubTypes.Type(
value = CreateChatCompletionStreamResponse.class,
name = "chat.completion.chunk"),
})
public interface ChatCompletionCreate200ResponseMixIn {}
}
@@ -0,0 +1,91 @@
package com.sap.ai.sdk.foundationmodels.openai;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.Beta;
import com.sap.ai.sdk.core.common.StreamedDelta;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionsCreate200Response;
import com.sap.ai.sdk.foundationmodels.openai.model2.CompletionUsage;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionResponse;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionStreamResponse;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.RequiredArgsConstructor;
import lombok.Value;

/**
* OpenAI chat completion output delta for streaming.
*
* @since 1.3.0
*/
@Beta
@Value
@RequiredArgsConstructor(onConstructor_ = @JsonCreator)
public class OpenAiChatCompletionDelta implements StreamedDelta {
ChatCompletionsCreate200Response originalResponse;

@Nonnull
@Override
public String getDeltaContent() {
if (getOriginalResponse() instanceof CreateChatCompletionStreamResponse response) {
final var choices = response.getChoices();
if (!choices.isEmpty() && choices.get(0).getIndex() == 0) {
final var message = choices.get(0).getDelta();
if (message != null) {
return Objects.requireNonNullElse(message.getContent(), "");
}
}
}
if (getOriginalResponse() instanceof CreateChatCompletionResponse response) {
final var choices = response.getChoices();
if (!choices.isEmpty() && choices.get(0).getIndex() == 0) {
final var delta = choices.get(0).getCustomField("delta"); // .getMessage() does not work
if (delta instanceof String message) {
return message;
}
}
}
return "";
}

@Nullable
@Override
public String getFinishReason() {
if (getOriginalResponse() instanceof CreateChatCompletionStreamResponse response) {
final var choices = response.getChoices();
if (!choices.isEmpty()) {
final var finishReason = choices.get(0).getFinishReason();
return finishReason != null ? finishReason.getValue() : null;
}
}
if (getOriginalResponse() instanceof CreateChatCompletionResponse response) {
final var choices = response.getChoices();
if (!choices.isEmpty()) {
final var finishReason = choices.get(0).getFinishReason();
return finishReason != null ? finishReason.getValue() : null;
}
}
return null;
}

/**
* Get the completion usage from the response, or null if it is not available.
*
* @param objectMapper The object mapper to use for conversion.
* @return The completion usage or null.
*/
@Nullable
public CompletionUsage getCompletionUsage(@Nonnull final ObjectMapper objectMapper) {
if (getOriginalResponse() instanceof CreateChatCompletionStreamResponse response
&& response.getCustomFieldNames().contains("usage")
&& response.getCustomField("usage") instanceof Map<?, ?> usage) {
return objectMapper.convertValue(usage, CompletionUsage.class);
}
if (getOriginalResponse() instanceof CreateChatCompletionResponse response) {
return response.getUsage();
}
return null;
}
}
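The `getCompletionUsage` method above relies on Jackson's `ObjectMapper.convertValue` to turn an untyped custom-field map into a typed POJO. A minimal, self-contained sketch of that mechanism follows; the `Usage` class and its field names are simplified stand-ins, not the generated `CompletionUsage` type:

```java
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;

public class ConvertValueDemo {
  // Simplified stand-in for the generated usage type (hypothetical field names).
  public static class Usage {
    public int prompt_tokens;
    public int completion_tokens;
    public int total_tokens;
  }

  public static void main(String[] args) {
    ObjectMapper mapper = new ObjectMapper();
    // A custom field arrives as an untyped Map; convertValue maps it onto the POJO,
    // mirroring what getCompletionUsage does with the "usage" custom field.
    Map<String, Object> raw =
        Map.of("prompt_tokens", 11, "completion_tokens", 7, "total_tokens", 18);
    Usage usage = mapper.convertValue(raw, Usage.class);
    System.out.println(usage.total_tokens); // 18
  }
}
```

`convertValue` avoids a serialize/parse round trip through a JSON string, which is why it fits this "re-type a custom field" case.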
@@ -10,14 +10,18 @@
import com.sap.ai.sdk.core.common.ClientResponseHandler;
import com.sap.ai.sdk.core.common.ClientStreamingHandler;
import com.sap.ai.sdk.core.common.StreamedDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatSystemMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiError;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionRequestSystemMessage;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionRequestSystemMessageContent;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionRequestUserMessage;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionRequestUserMessageContent;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionStreamOptions;
import com.sap.ai.sdk.foundationmodels.openai.model2.ChatCompletionsCreate200Response;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionRequest;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionResponse;
import com.sap.ai.sdk.foundationmodels.openai.model2.CreateChatCompletionStreamResponse;
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreate200Response;
import com.sap.ai.sdk.foundationmodels.openai.model2.EmbeddingsCreateRequest;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.Destination;
@@ -126,13 +130,24 @@ public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) {
* @throws OpenAiClientException if the request fails
*/
@Nonnull
public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt)
public CreateChatCompletionResponse chatCompletion(@Nonnull final String prompt)
throws OpenAiClientException {
final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
final CreateChatCompletionRequest parameters = new CreateChatCompletionRequest();

if (systemPrompt != null) {
parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
parameters.addMessagesItem(
new ChatCompletionRequestSystemMessage()
.role(ChatCompletionRequestSystemMessage.RoleEnum.SYSTEM)
.content(ChatCompletionRequestSystemMessageContent.create(systemPrompt)));
}
parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
parameters
.addMessagesItem(
new ChatCompletionRequestUserMessage()
.role(ChatCompletionRequestUserMessage.RoleEnum.USER)
.content(ChatCompletionRequestUserMessageContent.create(prompt)))
.functions(null)
.tools(null)
.parallelToolCalls(null);
return chatCompletion(parameters);
}
Comment on lines 132 to 152

@newtork (Contributor, Author) commented on Jan 13, 2025:
This is the unfortunate diff expected for regular usage, without the convenience API.


@@ -144,16 +159,16 @@ public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt)
* @throws OpenAiClientException if the request fails
*/
@Nonnull
public OpenAiChatCompletionOutput chatCompletion(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
public CreateChatCompletionResponse chatCompletion(
@Nonnull final CreateChatCompletionRequest parameters) throws OpenAiClientException {
warnIfUnsupportedUsage();
return execute("/chat/completions", parameters, OpenAiChatCompletionOutput.class);
return execute("/chat/completions", parameters, CreateChatCompletionResponse.class);
}

/**
* Stream a completion for the given prompt. Returns a <b>lazily</b> populated stream of text
* chunks. To access more details about the individual chunks, use {@link
* #streamChatCompletionDeltas(OpenAiChatCompletionParameters)}.
* #streamChatCompletionDeltas(CreateChatCompletionRequest)}.
*
* <p>The stream should be consumed using a try-with-resources block to ensure that the underlying
* HTTP connection is closed.
@@ -173,22 +188,33 @@ public OpenAiChatCompletionOutput chatCompletion(
* @param prompt a text message.
* @return A stream of message deltas
* @throws OpenAiClientException if the request fails or if the finish reason is content_filter
* @see #streamChatCompletionDeltas(OpenAiChatCompletionParameters)
* @see #streamChatCompletionDeltas(CreateChatCompletionRequest)
*/
@Nonnull
public Stream<String> streamChatCompletion(@Nonnull final String prompt)
throws OpenAiClientException {
final OpenAiChatCompletionParameters parameters = new OpenAiChatCompletionParameters();
final CreateChatCompletionRequest parameters = new CreateChatCompletionRequest();

if (systemPrompt != null) {
parameters.addMessages(new OpenAiChatSystemMessage().setContent(systemPrompt));
parameters.addMessagesItem(
new ChatCompletionRequestSystemMessage()
.role(ChatCompletionRequestSystemMessage.RoleEnum.SYSTEM)
.content(ChatCompletionRequestSystemMessageContent.create(systemPrompt)));
}
parameters.addMessages(new OpenAiChatUserMessage().addText(prompt));
final var userMessage =
new ChatCompletionRequestUserMessage()
.role(ChatCompletionRequestUserMessage.RoleEnum.USER)
.content(ChatCompletionRequestUserMessageContent.create(prompt));
parameters.addMessagesItem(userMessage).tools(null).functions(null).parallelToolCalls(null);

return streamChatCompletionDeltas(parameters)
.map(OpenAiChatCompletionDelta.class::cast)
.peek(OpenAiClient::throwOnContentFilter)
.map(OpenAiChatCompletionDelta::getDeltaContent);
}
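The Javadoc above requires consuming the stream in a try-with-resources block so that the underlying HTTP connection is closed. The mechanism behind this is Java's `Stream.onClose` hook; a minimal, self-contained sketch, in which the close handler stands in for releasing the HTTP connection:

```java
import java.util.stream.Stream;

public class StreamCloseDemo {
  public static void main(String[] args) {
    StringBuilder log = new StringBuilder();
    // A lazily populated stream with a close handler; try-with-resources
    // guarantees the handler runs even if consumption throws.
    try (Stream<String> deltas = Stream.of("Hel", "lo").onClose(() -> log.append("[closed]"))) {
      deltas.forEach(log::append);
    }
    System.out.println(log); // Hello[closed]
  }
}
```

Without the try-with-resources block, nothing would ever invoke the close handler, which for the real client would leak the HTTP connection.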

private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelta delta) {
private static void throwOnContentFilter(
@Nonnull final com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta delta) {
final String finishReason = delta.getFinishReason();
if (finishReason != null && finishReason.equals("content_filter")) {
throw new OpenAiClientException("Content filter filtered the output.");
@@ -224,9 +250,9 @@ private static void throwOnContentFilter(@Nonnull final OpenAiChatCompletionDelt
*/
@Nonnull
public Stream<OpenAiChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull final OpenAiChatCompletionParameters parameters) throws OpenAiClientException {
@Nonnull final CreateChatCompletionRequest parameters) throws OpenAiClientException {
warnIfUnsupportedUsage();
parameters.enableStreaming();
parameters.stream(true).streamOptions(new ChatCompletionStreamOptions().includeUsage(true));
return executeStream("/chat/completions", parameters, OpenAiChatCompletionDelta.class);
}

@@ -246,9 +272,9 @@ private void warnIfUnsupportedUsage() {
* @throws OpenAiClientException if the request fails
*/
@Nonnull
public OpenAiEmbeddingOutput embedding(@Nonnull final OpenAiEmbeddingParameters parameters)
public EmbeddingsCreate200Response embedding(@Nonnull final EmbeddingsCreateRequest parameters)
throws OpenAiClientException {
return execute("/embeddings", parameters, OpenAiEmbeddingOutput.class);
return execute("/embeddings", parameters, EmbeddingsCreate200Response.class);
}

@Nonnull
@@ -300,6 +326,17 @@ private <D extends StreamedDelta> Stream<D> streamRequest(
try {
final var client = ApacheHttpClient5Accessor.getHttpClient(destination);
return new ClientStreamingHandler<>(deltaType, OpenAiError.class, OpenAiClientException::new)
.objectMapper(
JACKSON
.addMixIn(
CreateChatCompletionResponse.class,
JacksonMixins.CreateChatCompletionResponseMixIn.class)
.addMixIn(
CreateChatCompletionStreamResponse.class,
JacksonMixins.CreateChatCompletionStreamResponseMixIn.class)
.addMixIn(
ChatCompletionsCreate200Response.class,
JacksonMixins.ChatCompletionCreate200ResponseMixIn.class))
.handleStreamingResponse(client.executeOpen(null, request, null));
Comment on lines 328 to 340

@newtork (Contributor, Author) commented on Jan 13, 2025:
Workaround, because Jackson is unable to determine the deserializer by itself for:

  • ChatCompletionsCreate200Response
    • CreateChatCompletionStreamResponse extends *
    • CreateChatCompletionResponse extends *
@rpanackal (Member) asked on Jan 17, 2025:

Maybe I missed the problem, but I don't see why deduction would not work, as there is an additional property prompt_filter_results in CreateChatCompletionResponse that is not present in the other.

According to my understanding, as long as there is a distinguishing property, Jackson should be able to resolve the subtype.

@newtork (Contributor, Author) replied:

> don't see the reason why deduction would not work

The actual payload for streamed responses differs significantly from the specification; it is inconsistent. Whatever annotation / setting we choose, it will be a compromise. If you find deduction working, or if you prefer a different mixin - please be my guest and change it as part of the upcoming convenience API draft.
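For context on the mixin workaround debated here: with `JsonTypeInfo.Id.NAME` routing on the "object" discriminator (the approach taken in JacksonMixins), Jackson resolves the concrete subtype from the property value rather than by deducing it from the payload shape. A minimal, self-contained sketch with simplified stand-in classes, not the generated SDK types:

```java
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.databind.ObjectMapper;

public class MixinDemo {
  // Simplified stand-ins for the generated response hierarchy (hypothetical classes).
  static class BaseResponse {
    public String object;
  }

  static class CompletionResponse extends BaseResponse {}

  static class StreamChunkResponse extends BaseResponse {}

  // Mix-in that routes on the "object" discriminator, analogous to
  // ChatCompletionCreate200ResponseMixIn in the diff above.
  @JsonTypeInfo(
      use = JsonTypeInfo.Id.NAME,
      property = "object",
      defaultImpl = CompletionResponse.class,
      visible = true)
  @JsonSubTypes({
    @JsonSubTypes.Type(value = CompletionResponse.class, name = "chat.completion"),
    @JsonSubTypes.Type(value = StreamChunkResponse.class, name = "chat.completion.chunk")
  })
  interface BaseResponseMixIn {}

  public static void main(String[] args) throws Exception {
    // addMixIn attaches the annotations without touching the generated classes.
    ObjectMapper mapper = new ObjectMapper().addMixIn(BaseResponse.class, BaseResponseMixIn.class);

    BaseResponse a = mapper.readValue("{\"object\":\"chat.completion\"}", BaseResponse.class);
    BaseResponse b = mapper.readValue("{\"object\":\"chat.completion.chunk\"}", BaseResponse.class);

    System.out.println(a.getClass().getSimpleName()); // CompletionResponse
    System.out.println(b.getClass().getSimpleName()); // StreamChunkResponse
  }
}
```

This sketches why name-based routing is robust against inconsistent payloads: it only depends on the "object" value being present, not on any other distinguishing property, which matters if the streamed payload deviates from the specification as described above.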

} catch (final IOException e) {
throw new OpenAiClientException("Request to OpenAI model failed", e);