diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleServiceDemo.java index 3f57bece..2aa1308f 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleServiceDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleServiceDemo.java @@ -1,5 +1,6 @@ package io.github.sashirestela.openai.demo; +import java.util.ArrayList; import io.github.sashirestela.openai.SimpleOpenAIAnyscale; import io.github.sashirestela.openai.demo.ChatServiceDemo.Product; @@ -13,34 +14,31 @@ import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser; import io.github.sashirestela.openai.domain.chat.tool.ChatFunction; import io.github.sashirestela.openai.function.FunctionExecutor; -import java.util.ArrayList; public class ChatAnyscaleServiceDemo extends AbstractDemo { public static final String MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"; - private ChatRequest chatRequest; - public ChatAnyscaleServiceDemo(String apiKey, String model) { super(SimpleOpenAIAnyscale.builder().apiKey(apiKey).build()); chatRequest = ChatRequest.builder() - .model(model) - .message(new ChatMsgSystem("You are an expert in AI.")) - .message( - new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words.")) - .temperature(0.0) - .maxTokens(300) - .build(); + .model(model) + .message(new ChatMsgSystem("You are an expert in AI.")) + .message( + new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words.")) + .temperature(0.0) + .maxTokens(300) + .build(); } public void demoCallChatStreaming() { var futureChat = openAI.chatCompletions().createStream(chatRequest); var chatResponse = futureChat.join(); chatResponse.filter(chatResp -> chatResp.firstContent() != null) - .map(ChatResponse::firstContent) - .forEach(System.out::print); + .map(ChatResponse::firstContent) + .forEach(System.out::print); System.out.println(); } @@ -53,30 +51,30 @@ public void demoCallChatBlocking() { public void demoCallChatWithFunctions() { var functionExecutor = new FunctionExecutor(); functionExecutor.enrollFunction( - ChatFunction.builder() - .name("get_weather") - .description("Get the current weather of a location") - .functionalClass(Weather.class) - .build()); + ChatFunction.builder() + .name("get_weather") + .description("Get the current weather of a location") + .functionalClass(Weather.class) + .build()); functionExecutor.enrollFunction( - ChatFunction.builder() - .name("product") - .description("Get the product of two numbers") - .functionalClass(Product.class) - .build()); + ChatFunction.builder() + .name("product") + .description("Get the product of two numbers") + .functionalClass(Product.class) + .build()); functionExecutor.enrollFunction( - ChatFunction.builder() - .name("run_alarm") - .description("Run an alarm") - .functionalClass(RunAlarm.class) - .build()); + ChatFunction.builder() + .name("run_alarm") + .description("Run an alarm") + .functionalClass(RunAlarm.class) + .build()); var messages = new ArrayList(); messages.add(new ChatMsgUser("What is the product of 123 and 456?")); var chatRequest = ChatRequest.builder() - .model(MODEL) - .messages(messages) - .tools(functionExecutor.getToolFunctions()) - .build(); + .model(MODEL) + .messages(messages) + .tools(functionExecutor.getToolFunctions()) + .build(); var futureChat = openAI.chatCompletions().create(chatRequest); var chatResponse = futureChat.join(); var chatMessage = chatResponse.firstMessage(); @@ -85,10 +83,10 @@ public void demoCallChatWithFunctions() { messages.add(chatMessage); messages.add(new ChatMsgTool(result.toString(), chatToolCall.getId())); chatRequest = ChatRequest.builder() - .model(MODEL) - .messages(messages) - .tools(functionExecutor.getToolFunctions()) - .build(); + .model(MODEL) + .messages(messages) + .tools(functionExecutor.getToolFunctions()) + .build(); futureChat = openAI.chatCompletions().create(chatRequest); chatResponse = futureChat.join(); System.out.println(chatResponse.firstContent()); @@ -96,7 +94,7 @@ public void demoCallChatWithFunctions() { public static void main(String[] args) { var apiKey = System.getenv("ANYSCALE_API_KEY"); - // Services like Azure OpenAI don't require a model (endpoints have built-in model) + var demo = new ChatAnyscaleServiceDemo(apiKey, MODEL); demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java index ee6cc746..9dfd066e 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java @@ -1,5 +1,11 @@ package io.github.sashirestela.openai.demo; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; import io.github.sashirestela.openai.SimpleOpenAIAzure; import io.github.sashirestela.openai.demo.ChatServiceDemo.Product; @@ -16,38 +22,32 @@ import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser; import io.github.sashirestela.openai.domain.chat.tool.ChatFunction; import io.github.sashirestela.openai.function.FunctionExecutor; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Base64; -import java.util.List; public class ChatAzureServiceDemo extends AbstractDemo { private ChatRequest chatRequest; public ChatAzureServiceDemo(String baseUrl, String apiKey, String apiVersion) { super(SimpleOpenAIAzure.builder() - .apiKey(apiKey) - .baseUrl(baseUrl) - .apiVersion(apiVersion) - .build()); + .apiKey(apiKey) + .baseUrl(baseUrl) + .apiVersion(apiVersion) + .build()); chatRequest = ChatRequest.builder() - .model("N/A") - .message(new ChatMsgSystem("You are an expert in AI.")) - .message( - new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words.")) - .temperature(0.0) - .maxTokens(300) - .build(); + .model("N/A") + .message(new ChatMsgSystem("You are an expert in AI.")) + .message( + new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words.")) + .temperature(0.0) + .maxTokens(300) + .build(); } public void demoCallChatStreaming() { var futureChat = openAI.chatCompletions().createStream(chatRequest); var chatResponse = futureChat.join(); chatResponse.filter(chatResp -> chatResp.firstContent() != null) - .map(ChatResponse::firstContent) - .forEach(System.out::print); + .map(ChatResponse::firstContent) + .forEach(System.out::print); System.out.println(); } @@ -60,30 +60,30 @@ public void demoCallChatBlocking() { public void demoCallChatWithFunctions() { var functionExecutor = new FunctionExecutor(); functionExecutor.enrollFunction( - ChatFunction.builder() - .name("get_weather") - .description("Get the current weather of a location") - .functionalClass(Weather.class) - .build()); + ChatFunction.builder() + .name("get_weather") + .description("Get the current weather of a location") + .functionalClass(Weather.class) + .build()); functionExecutor.enrollFunction( - ChatFunction.builder() - .name("product") - .description("Get the product of two numbers") - .functionalClass(Product.class) - .build()); + ChatFunction.builder() + .name("product") + .description("Get the product of two numbers") + .functionalClass(Product.class) + .build()); functionExecutor.enrollFunction( - ChatFunction.builder() - .name("run_alarm") - .description("Run an alarm") - .functionalClass(RunAlarm.class) - .build()); + ChatFunction.builder() + .name("run_alarm") + .description("Run an alarm") + .functionalClass(RunAlarm.class) + .build()); var messages = new ArrayList(); messages.add(new ChatMsgUser("What is the product of 123 and 456?")); chatRequest = ChatRequest.builder() - .model("N/A") - .messages(messages) - .tools(functionExecutor.getToolFunctions()) - .build(); + .model("N/A") + .messages(messages) + .tools(functionExecutor.getToolFunctions()) + .build(); var futureChat = openAI.chatCompletions().create(chatRequest); var chatResponse = futureChat.join(); var chatMessage = chatResponse.firstMessage(); @@ -92,10 +92,10 @@ public void demoCallChatWithFunctions() { messages.add(chatMessage); messages.add(new ChatMsgTool(result.toString(), chatToolCall.getId())); chatRequest = ChatRequest.builder() - .model("N/A") - .messages(messages) - .tools(functionExecutor.getToolFunctions()) - .build(); + .model("N/A") + .messages(messages) + .tools(functionExecutor.getToolFunctions()) + .build(); futureChat = openAI.chatCompletions().create(chatRequest); chatResponse = futureChat.join(); System.out.println(chatResponse.firstContent()); @@ -103,38 +103,38 @@ public void demoCallChatWithFunctions() { public void demoCallChatWithVisionExternalImage() { var chatRequest = ChatRequest.builder() - .model("N/A") - .messages(List.of( - new ChatMsgUser(List.of( - new ContentPartText( - "What do you see in the image? Give in details in no more than 100 words."), - new ContentPartImage(new ImageUrl( - "https://upload.wikimedia.org/wikipedia/commons/e/eb/Machu_Picchu%2C_Peru.jpg")))))) - .temperature(0.0) - .maxTokens(500) - .build(); + .model("N/A") + .messages(List.of( + new ChatMsgUser(List.of( + new ContentPartText( + "What do you see in the image? Give in details in no more than 100 words."), + new ContentPartImage(new ImageUrl( + "https://upload.wikimedia.org/wikipedia/commons/e/eb/Machu_Picchu%2C_Peru.jpg")))))) + .temperature(0.0) + .maxTokens(500) + .build(); var chatResponse = openAI.chatCompletions().createStream(chatRequest).join(); chatResponse.filter(chatResp -> chatResp.firstContent() != null) - .map(chatResp -> chatResp.firstContent()) - .forEach(System.out::print); + .map(chatResp -> chatResp.firstContent()) + .forEach(System.out::print); System.out.println(); } public void demoCallChatWithVisionLocalImage() { var chatRequest = ChatRequest.builder() - .model("N/A") - .messages(List.of( - new ChatMsgUser(List.of( - new ContentPartText( - "What do you see in the image? Give in details in no more than 100 words."), - new ContentPartImage(loadImageAsBase64("src/demo/resources/machupicchu.jpg")))))) - .temperature(0.0) - .maxTokens(500) - .build(); + .model("N/A") + .messages(List.of( + new ChatMsgUser(List.of( + new ContentPartText( + "What do you see in the image? Give in details in no more than 100 words."), + new ContentPartImage(loadImageAsBase64("src/demo/resources/machupicchu.jpg")))))) + .temperature(0.0) + .maxTokens(500) + .build(); var chatResponse = openAI.chatCompletions().createStream(chatRequest).join(); chatResponse.filter(chatResp -> chatResp.firstContent() != null) - .map(chatResp -> chatResp.firstContent()) - .forEach(System.out::print); + .map(chatResp -> chatResp.firstContent()) + .forEach(System.out::print); System.out.println(); } @@ -156,14 +156,13 @@ public static void main(String[] args) { var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL"); var apiKey = System.getenv("AZURE_OPENAI_API_KEY"); var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION"); - // Services like Azure OpenAI don't require a model (endpoints have built-in model) - var demo = new ChatAzureServiceDemo(baseUrl, apiKey, apiVersion); + var demo = new ChatAzureServiceDemo(baseUrl, apiKey, apiVersion); demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); if (baseUrl.contains("gpt-35-turbo")) { demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions); - } else if (baseUrl.contains("gpt-4")){ + } else if (baseUrl.contains("gpt-4")) { demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage); demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage); diff --git a/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAI.java b/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAI.java index 965004f0..3d888273 100644 --- a/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAI.java +++ b/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAI.java @@ -1,19 +1,18 @@ package io.github.sashirestela.openai; -import io.github.sashirestela.cleverclient.CleverClient; import java.net.http.HttpClient; import java.util.Optional; + +import io.github.sashirestela.cleverclient.CleverClient; import lombok.NonNull; import lombok.Setter; - /** - * The base abstract class that all providers extend. It generates - * an implementation to the chatCompletions() interface of {@link OpenAI OpenAI} interfaces. - * It throws a "Not implemented" exception for all other interfaces + * The base abstract class that all providers extend. It generates an + * implementation to the chatCompletions() interface of {@link OpenAI OpenAI} + * interfaces. It throws a "Not implemented" exception for all other interfaces */ - public abstract class BaseSimpleOpenAI { private static final String END_OF_STREAM = "[DONE]"; @@ -24,22 +23,21 @@ public abstract class BaseSimpleOpenAI { protected OpenAI.ChatCompletions chatCompletionService; BaseSimpleOpenAI(@NonNull BaseSimpleOpenAIArgs args) { - var httpClient = - Optional.ofNullable(args.getHttpClient()).orElse(HttpClient.newHttpClient()); + var httpClient = Optional.ofNullable(args.getHttpClient()).orElse(HttpClient.newHttpClient()); this.cleverClient = CleverClient.builder() - .httpClient(httpClient) - .baseUrl(args.getBaseUrl()) - .headers(args.getHeaders()) - .endOfStream(END_OF_STREAM) - .requestInterceptor(args.getRequestInterceptor()) - .build(); + .httpClient(httpClient) + .baseUrl(args.getBaseUrl()) + .headers(args.getHeaders()) + .endOfStream(END_OF_STREAM) + .requestInterceptor(args.getRequestInterceptor()) + .build(); } /** * Throw not implemented */ public OpenAI.Audios audios() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** @@ -60,62 +58,62 @@ public OpenAI.ChatCompletions chatCompletions() { * Throw not implemented */ public OpenAI.Completions completions() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** * Throw not implemented */ public OpenAI.Embeddings embeddings() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** * Throw not implemented */ public OpenAI.Files files() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** * Throw not implemented */ public OpenAI.FineTunings fineTunings() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** * Throw not implemented */ public OpenAI.Images images() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** * Throw not implemented */ public OpenAI.Models models() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** * Throw not implemented */ public OpenAI.Moderations moderations() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** * Throw not implemented */ public OpenAI.Assistants assistants() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } /** * Throw not implemented */ public OpenAI.Threads threads() { - throw new SimpleUncheckedException("Not implemented"); + throw new UnsupportedOperationException("Not implemented"); } } diff --git a/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAIArgs.java b/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAIArgs.java index 6e60bd89..41fa6563 100644 --- a/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAIArgs.java +++ b/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAIArgs.java @@ -1,9 +1,10 @@ package io.github.sashirestela.openai; -import io.github.sashirestela.cleverclient.http.HttpRequestData; import java.net.http.HttpClient; import java.util.Map; import java.util.function.UnaryOperator; + +import io.github.sashirestela.cleverclient.http.HttpRequestData; import lombok.Builder; import lombok.Getter; import lombok.NonNull; diff --git a/src/main/java/io/github/sashirestela/openai/OpenAI.java b/src/main/java/io/github/sashirestela/openai/OpenAI.java index 1636b5d5..0e3538e1 100644 --- a/src/main/java/io/github/sashirestela/openai/OpenAI.java +++ b/src/main/java/io/github/sashirestela/openai/OpenAI.java @@ -2,8 +2,6 @@ import static io.github.sashirestela.cleverclient.util.CommonUtil.isNullOrEmpty; -import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoice; -import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoiceType; import java.io.InputStream; import java.util.EnumSet; import java.util.List; @@ -21,8 +19,8 @@ import io.github.sashirestela.cleverclient.annotation.Resource; import io.github.sashirestela.openai.domain.OpenAIDeletedResponse; import io.github.sashirestela.openai.domain.OpenAIGeneric; -import io.github.sashirestela.openai.domain.PageRequest; import io.github.sashirestela.openai.domain.Page; +import io.github.sashirestela.openai.domain.PageRequest; import io.github.sashirestela.openai.domain.assistant.Assistant; import io.github.sashirestela.openai.domain.assistant.AssistantFile; import io.github.sashirestela.openai.domain.assistant.AssistantRequest; @@ -45,6 +43,7 @@ import io.github.sashirestela.openai.domain.audio.AudioTranslateRequest; import io.github.sashirestela.openai.domain.chat.ChatRequest; import io.github.sashirestela.openai.domain.chat.ChatResponse; +import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoiceType; import io.github.sashirestela.openai.domain.completion.CompletionRequest; import io.github.sashirestela.openai.domain.completion.CompletionResponse; import io.github.sashirestela.openai.domain.embedding.EmbeddingBase64Response; @@ -73,25 +72,6 @@ */ public interface OpenAI { - static ChatRequest updateRequest(ChatRequest chatRequest, Boolean useStream) { - var toolChoice = chatRequest.getToolChoice(); - - if (!isNullOrEmpty(chatRequest.getTools())) { - if (toolChoice == null) { - toolChoice = ChatToolChoiceType.AUTO; - } else if (!(toolChoice instanceof ChatToolChoice) && - !(toolChoice instanceof ChatToolChoiceType)) { - throw new SimpleUncheckedException( - "The field toolChoice must be ChatToolChoiceType or ChatToolChoice classes.", - null, null); - } - } - return chatRequest - .withStream(useStream) - .withToolChoice(toolChoice); - } - - /** * Turn audio into text (speech to text). * @@ -178,23 +158,6 @@ default CompletableFuture translatePlain(AudioTranslateRequest audioRequ @Multipart @POST("/translations") CompletableFuture __translatePlain(@Body AudioTranslateRequest audioRequest); - - private AudioRespFmt getResponseFormat(AudioRespFmt currValue, AudioRespFmt orDefault, String methodName) { - final var jsonEnumSet = EnumSet.of(AudioRespFmt.JSON, AudioRespFmt.VERBOSE_JSON); - final var textEnumSet = EnumSet.complementOf(jsonEnumSet); - - var isText = textEnumSet.contains(orDefault); - var requestedFormat = currValue; - if (requestedFormat != null) { - if (isText != textEnumSet.contains(requestedFormat)) { - throw new SimpleUncheckedException("Unexpected responseFormat for the method {0}.", - methodName, null); - } - } else { - requestedFormat = orDefault; - } - return requestedFormat; - } } /** @@ -1032,4 +995,29 @@ CompletableFuture> getRunStepList(@Path("threadId") String t @Path("runId") String runId, @Query PageRequest page); } + + static AudioRespFmt getResponseFormat(AudioRespFmt currValue, AudioRespFmt orDefault, String methodName) { + final var jsonEnumSet = EnumSet.of(AudioRespFmt.JSON, AudioRespFmt.VERBOSE_JSON); + final var textEnumSet = EnumSet.complementOf(jsonEnumSet); + + var isText = textEnumSet.contains(orDefault); + var requestedFormat = currValue; + if (requestedFormat != null) { + if (isText != textEnumSet.contains(requestedFormat)) { + throw new SimpleUncheckedException("Unexpected responseFormat for the method {0}.", + methodName, null); + } + } else { + requestedFormat = orDefault; + } + return requestedFormat; + } + + static ChatRequest updateRequest(ChatRequest chatRequest, Boolean useStream) { + var updatedChatRequest = chatRequest.withStream(useStream); + if (!isNullOrEmpty(chatRequest.getTools()) && chatRequest.getToolChoice() == null) { + updatedChatRequest = updatedChatRequest.withToolChoice(ChatToolChoiceType.AUTO); + } + return updatedChatRequest; + } } \ No newline at end of file diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java index c5468c84..0019604a 100644 --- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java @@ -3,56 +3,28 @@ import java.net.http.HttpClient; import java.util.HashMap; import java.util.Optional; + +import io.github.sashirestela.openai.support.Constant; import lombok.Builder; import lombok.NonNull; /** - * This class provides the implements additional {@link OpenAI OpenAI} interfaces - * targeting the OpenAI service. + * This class provides the implements additional {@link OpenAI OpenAI} + * interfaces targeting the OpenAI service. */ public class SimpleOpenAI extends BaseSimpleOpenAI { - public static final String OPENAI_BASE_URL = "https://api.openai.com"; - public static final String AUTHORIZATION_HEADER = "Authorization"; - public static final String ORGANIZATION_HEADER = "OpenAI-Organization"; - public static final String BEARER_AUTHORIZATION = "Bearer "; - private OpenAI.Audios audioService; private OpenAI.Completions completionService; - private OpenAI.Embeddings embeddingService; - private OpenAI.Files fileService; - private OpenAI.FineTunings fineTuningService; - private OpenAI.Images imageService; - private OpenAI.Models modelService; - private OpenAI.Moderations moderationService; - private OpenAI.Assistants assistantService; - private OpenAI.Threads threadService; - - public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs( - String apiKey, String organizationId, String baseUrl, HttpClient httpClient) { - - var headers = new HashMap(); - headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey); - if (organizationId != null) { - headers.put(ORGANIZATION_HEADER, organizationId); - } - - return BaseSimpleOpenAIArgs.builder() - .baseUrl(Optional.ofNullable(baseUrl).orElse(OPENAI_BASE_URL)) - .headers(headers) - .httpClient(httpClient) - .build(); - } - /** * Constructor used to generate a builder. * @@ -68,6 +40,22 @@ public SimpleOpenAI(@NonNull String apiKey, String organizationId, String baseUr super(prepareBaseSimpleOpenAIArgs(apiKey, organizationId, baseUrl, httpClient)); } + public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs( + String apiKey, String organizationId, String baseUrl, HttpClient httpClient) { + + var headers = new HashMap(); + headers.put(Constant.AUTHORIZATION_HEADER, Constant.BEARER_AUTHORIZATION + apiKey); + if (organizationId != null) { + headers.put(Constant.OPENAI_ORG_HEADER, organizationId); + } + + return BaseSimpleOpenAIArgs.builder() + .baseUrl(Optional.ofNullable(baseUrl).orElse(Constant.OPENAI_BASE_URL)) + .headers(headers) + .httpClient(httpClient) + .build(); + } + /** * Generates an implementation of the Audios interface to handle requests. * diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAnyscale.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAnyscale.java index 70c8b580..155fb5e3 100644 --- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAnyscale.java +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAnyscale.java @@ -1,8 +1,10 @@ package io.github.sashirestela.openai; import java.net.http.HttpClient; -import java.util.HashMap; +import java.util.Map; import java.util.Optional; + +import io.github.sashirestela.openai.support.Constant; import lombok.Builder; import lombok.NonNull; @@ -10,33 +12,28 @@ * This class provides the chatCompletion() service for the Anyscale provider */ public class SimpleOpenAIAnyscale extends BaseSimpleOpenAI { - public static final String DEFAULT_BASE_URL = "https://api.endpoints.anyscale.com"; - - public static final String AUTHORIZATION_HEADER = "Authorization"; - public static final String BEARER_AUTHORIZATION = "Bearer "; /** * Constructor used to generate a builder. * - * @param apiKey Identifier to be used for authentication. Mandatory. - * @param baseUrl Host's url - * @param httpClient A {@link java.net.http.HttpClient HttpClient} object. - * One is created by default if not provided. Optional. + * @param apiKey Identifier to be used for authentication. Mandatory. + * @param baseUrl Host's url. Optional. + * @param httpClient A {@link java.net.http.HttpClient HttpClient} object. + * One is created by default if not provided. Optional. */ - public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, String baseUrl, HttpClient httpClient) { - baseUrl = Optional.ofNullable(baseUrl).orElse(DEFAULT_BASE_URL); - var headers = new HashMap(); - headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey); - - return BaseSimpleOpenAIArgs.builder() - .baseUrl(baseUrl) - .headers(headers) - .httpClient(httpClient) - .build(); - } - @Builder public SimpleOpenAIAnyscale(@NonNull String apiKey, String baseUrl, HttpClient httpClient) { super(prepareBaseSimpleOpenAIArgs(apiKey, baseUrl, httpClient)); } -} + + public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, String baseUrl, + HttpClient httpClient) { + var headers = Map.of(Constant.AUTHORIZATION_HEADER, Constant.BEARER_AUTHORIZATION + apiKey); + + return BaseSimpleOpenAIArgs.builder() + .baseUrl(Optional.ofNullable(baseUrl).orElse(Constant.ANYSCALE_BASE_URL)) + .headers(headers) + .httpClient(httpClient) + .build(); + } +} \ No newline at end of file diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java index e9c9c223..ff472638 100644 --- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java @@ -1,42 +1,59 @@ package io.github.sashirestela.openai; -import io.github.sashirestela.cleverclient.http.HttpRequestData; -import io.github.sashirestela.cleverclient.support.ContentType; import java.net.http.HttpClient; import java.util.Map; import java.util.function.UnaryOperator; + +import io.github.sashirestela.cleverclient.http.HttpRequestData; +import io.github.sashirestela.cleverclient.support.ContentType; +import io.github.sashirestela.openai.support.Constant; import lombok.Builder; import lombok.NonNull; /** - * This class provides the chatCompletion() service for the Azure OpenAI provider - * Note that each instance of SimpleOpenAIAzure is linked to a single specific model. - * The capabilities of the model determine which chatCompletion() methods are available. + * This class provides the chatCompletion() service for the Azure OpenAI + * provider. Note that each instance of SimpleOpenAIAzure is linked to a single + * specific model. The capabilities of the model determine which + * chatCompletion() methods are available. */ public class SimpleOpenAIAzure extends BaseSimpleOpenAI { - public static final String API_KEY_HEADER = "api-key"; - public static final String API_VERSION = "api-version"; - - private static final String ENDPOINT_VERSION_REGEX = "(\\/v\\d+\\.*\\d*)"; - private static final String MODEL_REGEX = ",?\"model\":\"[^\"]*\",?"; - - private static final String EMPTY_REGEX = "\"\""; - private static final String QUOTED_COMMA = "\",\""; + /** + * Constructor used to generate a builder. + * + * @param apiKey Identifier to be used for authentication. Mandatory. + * @param baseUrl The URL of the Azure OpenAI deployment. Mandatory. + * @param apiVersion Azure OpenAI API version. See: + * Azure + * OpenAI API versioning. Mandatory. + * @param httpClient A {@link HttpClient HttpClient} object. + * One is created by default if not provided. Optional. + */ + @Builder + public SimpleOpenAIAzure(@NonNull String apiKey, @NonNull String baseUrl, @NonNull String apiVersion, + HttpClient httpClient) { + super(prepareBaseSimpleOpenAIArgs(apiKey, baseUrl, apiVersion, httpClient)); + } - private static final String MODEL_LITERAL = "model"; + public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, String baseUrl, String apiVersion, + HttpClient httpClient) { - public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, String baseUrl, String apiVersion, HttpClient httpClient) { + var headers = Map.of(Constant.AZURE_APIKEY_HEADER, apiKey); - var headers = Map.of(API_KEY_HEADER, apiKey); + UnaryOperator requestInterceptor = request -> { + final String VERSION_REGEX = "(\\/v\\d+\\.*\\d*)"; + final String MODEL_REGEX = ",?\"model\":\"[^\"]*\",?"; + final String EMPTY_REGEX = "\"\""; + final String QUOTED_COMMA = "\",\""; + final String MODEL_LITERAL = "model"; - var requestInterceptor = (UnaryOperator) request -> { var url = request.getUrl(); var contentType = request.getContentType(); var body = request.getBody(); - url += (url.contains("?") ? "&" : "?") + API_VERSION + "=" + apiVersion; - url = url.replaceFirst(ENDPOINT_VERSION_REGEX, ""); + url += (url.contains("?") ? "&" : "?") + Constant.AZURE_API_VERSION + "=" + apiVersion; + url = url.replaceFirst(VERSION_REGEX, ""); request.setUrl(url); if (contentType != null) { @@ -59,25 +76,10 @@ public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, St }; return BaseSimpleOpenAIArgs.builder() - .baseUrl(baseUrl) - .headers(headers) - .httpClient(httpClient) - .requestInterceptor(requestInterceptor) - .build(); - } - - /** - * Constructor used to generate a builder. - * - * @param apiKey Identifier to be used for authentication. Mandatory. - * @param baseUrl The URL of the Azure OpenAI deployment. Mandatory. - * @param apiVersion Azure OpenAI API version. See: - * Azure OpenAI API versioning - * @param httpClient A {@link HttpClient HttpClient} object. - * One is created by default if not provided. Optional. - */ - @Builder - public SimpleOpenAIAzure(@NonNull String apiKey, @NonNull String baseUrl, @NonNull String apiVersion, HttpClient httpClient) { - super(prepareBaseSimpleOpenAIArgs(apiKey, baseUrl, apiVersion, httpClient)); + .baseUrl(baseUrl) + .headers(headers) + .httpClient(httpClient) + .requestInterceptor(requestInterceptor) + .build(); } } diff --git a/src/main/java/io/github/sashirestela/openai/SimpleUncheckedException.java b/src/main/java/io/github/sashirestela/openai/SimpleUncheckedException.java index 3451e4d4..387baa6c 100644 --- a/src/main/java/io/github/sashirestela/openai/SimpleUncheckedException.java +++ b/src/main/java/io/github/sashirestela/openai/SimpleUncheckedException.java @@ -5,6 +5,10 @@ public class SimpleUncheckedException extends RuntimeException { + public SimpleUncheckedException(String message) { + super(message); + } + public SimpleUncheckedException(String message, Object... parameters) { super(MessageFormat.format(message, Arrays.copyOfRange(parameters, 0, parameters.length - 1)), (Throwable) parameters[parameters.length - 1]); diff --git a/src/main/java/io/github/sashirestela/openai/domain/chat/ChatRequest.java b/src/main/java/io/github/sashirestela/openai/domain/chat/ChatRequest.java index feef27bc..8aa15437 100644 --- a/src/main/java/io/github/sashirestela/openai/domain/chat/ChatRequest.java +++ b/src/main/java/io/github/sashirestela/openai/domain/chat/ChatRequest.java @@ -11,6 +11,8 @@ import io.github.sashirestela.openai.SimpleUncheckedException; import io.github.sashirestela.openai.domain.chat.message.ChatMsg; import io.github.sashirestela.openai.domain.chat.tool.ChatTool; +import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoice; +import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoiceType; import lombok.Builder; import lombok.Getter; import lombok.NonNull; @@ -22,16 +24,20 @@ @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public class ChatRequest { - @NonNull private String model; - @NonNull private List messages; + @NonNull + private String model; + @NonNull + private List messages; private ChatRespFmt responseFormat; private Integer seed; private List tools; - @With private Object toolChoice; + @With + private Object toolChoice; private Double temperature; private Double topP; private Integer n; - @With private Boolean stream; + @With + private Boolean stream; private Object stop; private Integer maxTokens; private Double presencePenalty; @@ -46,6 +52,13 @@ public ChatRequest(@NonNull String model, @NonNull @Singular List messa Integer seed, @Singular List tools, Object toolChoice, Double temperature, Double topP, Integer n, Boolean stream, Object stop, Integer maxTokens, Double presencePenalty, Double frequencyPenalty, Map logitBias, String user, Boolean logprobs, Integer topLogprobs) { + if (toolChoice != null && + !(toolChoice instanceof ChatToolChoiceType) + && !(toolChoice instanceof ChatToolChoice)) { + throw new SimpleUncheckedException( + "The field toolChoice must be ChatToolChoiceType or ChatToolChoice classes.", + null, null); + } if (stop != null && !(stop instanceof String) && !(stop instanceof List && ((List) stop).get(0) instanceof String && ((List) stop).size() <= 4)) { throw new SimpleUncheckedException( diff --git a/src/main/java/io/github/sashirestela/openai/support/Constant.java b/src/main/java/io/github/sashirestela/openai/support/Constant.java new file mode 100644 index 00000000..48ef8867 --- /dev/null +++ b/src/main/java/io/github/sashirestela/openai/support/Constant.java @@ -0,0 +1,16 @@ +package io.github.sashirestela.openai.support; + +public class Constant { + + public static final String AUTHORIZATION_HEADER = "Authorization"; + public static final String BEARER_AUTHORIZATION = "Bearer "; + + public static final String OPENAI_BASE_URL = "https://api.openai.com"; + public static final String OPENAI_ORG_HEADER = "OpenAI-Organization"; + + public static final String ANYSCALE_BASE_URL = "https://api.endpoints.anyscale.com"; + + public static final String AZURE_APIKEY_HEADER = "api-key"; + public static final String AZURE_API_VERSION = "api-version"; + +} \ No newline at end of file diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAnyscaleTest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAnyscaleTest.java index 24cc077a..710b5f81 100644 --- a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAnyscaleTest.java +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAnyscaleTest.java @@ -1,46 +1,61 @@ package io.github.sashirestela.openai; -import static io.github.sashirestela.openai.SimpleOpenAIAnyscale.AUTHORIZATION_HEADER; -import static io.github.sashirestela.openai.SimpleOpenAIAnyscale.BEARER_AUTHORIZATION; -import static io.github.sashirestela.openai.SimpleOpenAIAnyscale.DEFAULT_BASE_URL; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.net.http.HttpClient; + import org.junit.jupiter.api.Test; +import io.github.sashirestela.openai.support.Constant; + class SimpleOpenAIAnyscaleTest { @Test void shouldPrepareBaseOpenSimpleAIArgsCorrectlyWithCustomBaseURL() { - var args = SimpleOpenAIAnyscale.prepareBaseSimpleOpenAIArgs( - "the-api-key", - "https://example.org", - HttpClient.newHttpClient()); + var args = SimpleOpenAIAnyscale.prepareBaseSimpleOpenAIArgs("the-api-key", "https://example.org", + HttpClient.newHttpClient()); assertEquals("https://example.org", args.getBaseUrl()); assertEquals(1, args.getHeaders().size()); - assertEquals(BEARER_AUTHORIZATION + "the-api-key", args.getHeaders().get(AUTHORIZATION_HEADER)); + assertEquals(Constant.BEARER_AUTHORIZATION + "the-api-key", + args.getHeaders().get(Constant.AUTHORIZATION_HEADER)); assertNotNull(args.getHttpClient()); - - // No request interceptor for SimpleOpenAIAnyscale assertNull(args.getRequestInterceptor()); } @Test - void shouldPrepareBaseOpenSimpleAIArgsCorrectlyWithDefaultBaseURL() { - var args = SimpleOpenAIAnyscale.prepareBaseSimpleOpenAIArgs( - "the-api-key", - null, - HttpClient.newHttpClient()); + void shouldPrepareBaseOpenSimpleAIArgsCorrectlyWithOnlyApiKey() { + var args = SimpleOpenAIAnyscale.prepareBaseSimpleOpenAIArgs("the-api-key", null, null); - assertEquals(SimpleOpenAIAnyscale.DEFAULT_BASE_URL, args.getBaseUrl()); + assertEquals(Constant.ANYSCALE_BASE_URL, args.getBaseUrl()); assertEquals(1, args.getHeaders().size()); - assertEquals(BEARER_AUTHORIZATION + "the-api-key", args.getHeaders().get(AUTHORIZATION_HEADER)); - assertNotNull(args.getHttpClient()); - - // No request interceptor for SimpleOpenAIAnyscale + assertEquals(Constant.BEARER_AUTHORIZATION + "the-api-key", + args.getHeaders().get(Constant.AUTHORIZATION_HEADER)); + assertNull(args.getHttpClient()); assertNull(args.getRequestInterceptor()); } + @Test + void shouldThrownExceptionWhenCallingUnimplementedMethods() { + var openAI = SimpleOpenAIAnyscale.builder() + .apiKey("api-key-test") + .build(); + Runnable[] callingData = { + openAI::audios, + openAI::completions, + openAI::embeddings, + openAI::files, + openAI::fineTunings, + openAI::images, + openAI::models, + openAI::moderations, + openAI::assistants, + openAI::threads + }; + for (Runnable calling : callingData) { + assertThrows(UnsupportedOperationException.class, () -> calling.run()); + } + }; } diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java index a25cb1ba..c7fd5ce6 100644 --- a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java @@ -1,57 +1,109 @@ package io.github.sashirestela.openai; -import static io.github.sashirestela.openai.SimpleOpenAIAzure.API_KEY_HEADER; -import static io.github.sashirestela.openai.SimpleOpenAIAzure.API_VERSION; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; -import io.github.sashirestela.cleverclient.http.HttpRequestData; -import io.github.sashirestela.cleverclient.support.ContentType; import java.net.http.HttpClient; +import java.util.HashMap; import java.util.Map; + import org.junit.jupiter.api.Test; -class SimpleOpenAIAzureTest { +import io.github.sashirestela.cleverclient.http.HttpRequestData; +import io.github.sashirestela.cleverclient.support.ContentType; +import io.github.sashirestela.openai.support.Constant; +class SimpleOpenAIAzureTest { @Test void shouldPrepareBaseOpenSimpleAIArgsCorrectly() { - var args = SimpleOpenAIAzure.prepareBaseSimpleOpenAIArgs( - "the-api-key", - "https://example.org", - "12-34-5678", - HttpClient.newHttpClient()); + var args = SimpleOpenAIAzure.prepareBaseSimpleOpenAIArgs("the-api-key", "https://example.org", "12-34-5678", + HttpClient.newHttpClient()); assertEquals("https://example.org", args.getBaseUrl()); assertEquals(1, args.getHeaders().size()); - assertEquals("the-api-key", args.getHeaders().get(API_KEY_HEADER)); + assertEquals("the-api-key", args.getHeaders().get(Constant.AZURE_APIKEY_HEADER)); assertNotNull(args.getHttpClient()); assertNotNull(args.getRequestInterceptor()); } @Test - void shouldInterceptUrlCorrectly() { + void shouldInterceptUrlCorrectlyWhenBodyIsJson() { var request = HttpRequestData.builder() - .url("https://example.org/v1/endpoint") - .contentType(ContentType.APPLICATION_JSON) - .headers(Map.of(API_KEY_HEADER, "the-api-key")) - .body("{\"model\":\"model1\"}") - .build(); + .url("https://example.org/v1/endpoint") + .contentType(ContentType.APPLICATION_JSON) + .headers(Map.of(Constant.AZURE_APIKEY_HEADER, "the-api-key")) + .body("{\"model\":\"model1\"}") + .build(); var expectedRequest = HttpRequestData.builder() - .url("https://example.org/endpoint?" + API_VERSION + "=12-34-5678") - .contentType(ContentType.APPLICATION_JSON) - .headers(Map.of(API_KEY_HEADER, "the-api-key")) - .body("{}") - .build(); + .url("https://example.org/endpoint?" + Constant.AZURE_API_VERSION + "=12-34-5678") + .contentType(ContentType.APPLICATION_JSON) + .headers(Map.of(Constant.AZURE_APIKEY_HEADER, "the-api-key")) + .body("{}") + .build(); var args = SimpleOpenAIAzure.prepareBaseSimpleOpenAIArgs( - "the-api-key", - "https://example.org", - "12-34-5678", - null); + "the-api-key", + "https://example.org", + "12-34-5678", + null); var actualRequest = args.getRequestInterceptor().apply(request); - assertEquals(expectedRequest.getUrl() , actualRequest.getUrl()); + assertEquals(expectedRequest.getUrl(), actualRequest.getUrl()); assertEquals(expectedRequest.getContentType(), actualRequest.getContentType()); assertEquals(expectedRequest.getHeaders(), actualRequest.getHeaders()); assertEquals(expectedRequest.getBody(), actualRequest.getBody()); } -} + + @Test + void shouldInterceptUrlCorrectlyWhenBodyIsMap() { + Map data = new HashMap<>(); + data.put("model", "model1"); + + var request = HttpRequestData.builder() + .url("https://example.org/v1/endpoint") + .contentType(ContentType.MULTIPART_FORMDATA) + .headers(Map.of(Constant.AZURE_APIKEY_HEADER, "the-api-key")) + .body(data) + .build(); + var expectedRequest = HttpRequestData.builder() + .url("https://example.org/endpoint?" + Constant.AZURE_API_VERSION + "=12-34-5678") + .contentType(ContentType.MULTIPART_FORMDATA) + .headers(Map.of(Constant.AZURE_APIKEY_HEADER, "the-api-key")) + .body(Map.of()) + .build(); + var args = SimpleOpenAIAzure.prepareBaseSimpleOpenAIArgs( + "the-api-key", + "https://example.org", + "12-34-5678", + null); + var actualRequest = args.getRequestInterceptor().apply(request); + assertEquals(expectedRequest.getUrl(), actualRequest.getUrl()); + assertEquals(expectedRequest.getContentType(), actualRequest.getContentType()); + assertEquals(expectedRequest.getHeaders(), actualRequest.getHeaders()); + assertEquals(expectedRequest.getBody(), actualRequest.getBody()); + } + + @Test + void shouldThrownExceptionWhenCallingUnimplementedMethods() { + var openAI = SimpleOpenAIAzure.builder() + .apiKey("apiKey") + .baseUrl("baseUrl") + .apiVersion("apiVersion") + .build(); + Runnable[] callingData = { + openAI::audios, + openAI::completions, + openAI::embeddings, + openAI::files, + openAI::fineTunings, + openAI::images, + openAI::models, + openAI::moderations, + openAI::assistants, + openAI::threads + }; + for (Runnable calling : callingData) { + assertThrows(UnsupportedOperationException.class, () -> calling.run()); + } + }; +} \ No newline at end of file diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java index e7fd73bf..2641bffb 100644 --- a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java @@ -1,9 +1,5 @@ package io.github.sashirestela.openai; -import static io.github.sashirestela.openai.SimpleOpenAI.AUTHORIZATION_HEADER; -import static io.github.sashirestela.openai.SimpleOpenAI.BEARER_AUTHORIZATION; -import static io.github.sashirestela.openai.SimpleOpenAI.OPENAI_BASE_URL; -import static io.github.sashirestela.openai.SimpleOpenAI.ORGANIZATION_HEADER; import static java.util.concurrent.CompletableFuture.completedFuture; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -27,6 +23,7 @@ import io.github.sashirestela.cleverclient.CleverClient; import io.github.sashirestela.openai.domain.chat.ChatRequest; +import io.github.sashirestela.openai.support.Constant; class SimpleOpenAITest { @@ -35,36 +32,30 @@ class SimpleOpenAITest { @Test void shouldPrepareBaseOpenSimpleAIArgsCorrectly() { - - var args = SimpleOpenAI.prepareBaseSimpleOpenAIArgs( - "the-api-key", - "orgId", - "https://example.org", - HttpClient.newHttpClient()); + var args = SimpleOpenAI.prepareBaseSimpleOpenAIArgs("the-api-key", "orgId", "https://example.org", + HttpClient.newHttpClient()); assertEquals("https://example.org", args.getBaseUrl()); assertEquals(2, args.getHeaders().size()); - assertEquals(BEARER_AUTHORIZATION + "the-api-key", args.getHeaders().get(AUTHORIZATION_HEADER)); - assertEquals("orgId", args.getHeaders().get(ORGANIZATION_HEADER)); + assertEquals(Constant.BEARER_AUTHORIZATION + "the-api-key", + args.getHeaders().get(Constant.AUTHORIZATION_HEADER)); + assertEquals("orgId", args.getHeaders().get(Constant.OPENAI_ORG_HEADER)); assertNotNull(args.getHttpClient()); - - // No request interceptor for SimpleOpenAI assertNull(args.getRequestInterceptor()); } + @Test void shouldPrepareBaseOpenSimpleAIArgsCorrectlyWithOnlyApiKey() { var args = SimpleOpenAI.prepareBaseSimpleOpenAIArgs("the-api-key", null, null, null); - assertEquals(OPENAI_BASE_URL, args.getBaseUrl()); + assertEquals(Constant.OPENAI_BASE_URL, args.getBaseUrl()); assertEquals(1, args.getHeaders().size()); - assertEquals(BEARER_AUTHORIZATION + "the-api-key", args.getHeaders().get(AUTHORIZATION_HEADER)); - assertNotNull(args.getHttpClient()); - - // No request interceptor for SimpleOpenAI + assertEquals(Constant.BEARER_AUTHORIZATION + "the-api-key", + args.getHeaders().get(Constant.AUTHORIZATION_HEADER)); + assertNull(args.getHttpClient()); assertNull(args.getRequestInterceptor()); } - @Test @SuppressWarnings("unchecked") void shouldNotDuplicateContentTypeHeaderWhenCallingSimpleOpenAI() { diff --git a/src/test/java/io/github/sashirestela/openai/domain/chat/ChatDomainTest.java b/src/test/java/io/github/sashirestela/openai/domain/chat/ChatDomainTest.java index a5c7ec57..c3d010fd 100644 --- a/src/test/java/io/github/sashirestela/openai/domain/chat/ChatDomainTest.java +++ b/src/test/java/io/github/sashirestela/openai/domain/chat/ChatDomainTest.java @@ -7,9 +7,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; -import io.github.sashirestela.openai.OpenAI; -import io.github.sashirestela.openai.OpenAI.ChatCompletions; -import io.github.sashirestela.openai.domain.chat.tool.ChatTool; import java.io.IOException; import java.net.http.HttpClient; import java.util.List; @@ -21,6 +18,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import io.github.sashirestela.openai.OpenAI; import io.github.sashirestela.openai.SimpleOpenAI; import io.github.sashirestela.openai.SimpleUncheckedException; import io.github.sashirestela.openai.domain.DomainTestingHelper; @@ -208,29 +206,25 @@ void shouldCreateChatRequestWhenToolChoiceIsRightClass() { @Test void shouldUpdateChatRequestWithAutoToolChoiceWhenToolsAreProvidedWithoutToolChoice() { var charRequest = ChatRequest.builder() - .model("model") - .message(new ChatMsgUser("content")) - .tools(functionExecutor.getToolFunctions()) - .build(); + .model("model") + .message(new ChatMsgUser("content")) + .tools(functionExecutor.getToolFunctions()) + .build(); assertNull(charRequest.getToolChoice()); var updatedChatRequest = OpenAI.updateRequest(charRequest, Boolean.TRUE); assertEquals(ChatToolChoiceType.AUTO, updatedChatRequest.getToolChoice()); } - @Test void shouldThrownExceptionWhenCreatingChatRequestWithToolChoiceWrongClass() { - var charRequest = ChatRequest.builder() - .model("model") - .message(new ChatMsgUser("content")) - .tools(functionExecutor.getToolFunctions()) - .toolChoice("wrong value") - .build(); - - var exception = assertThrows(SimpleUncheckedException.class, () -> OpenAI.updateRequest(charRequest, Boolean.TRUE)); + var chatRequestBuilder = ChatRequest.builder() + .model("model") + .message(new ChatMsgUser("My Content")) + .toolChoice("wrong value"); + var exception = assertThrows(SimpleUncheckedException.class, () -> chatRequestBuilder.build()); var actualErrorMessage = exception.getMessage(); - var expectedErrorMessage = "The field toolChoice must be ChatToolChoiceType or ChatToolChoice classes."; - assertEquals(expectedErrorMessage, actualErrorMessage); + var expectedErrorMessge = "The field toolChoice must be ChatToolChoiceType or ChatToolChoice classes."; + assertEquals(expectedErrorMessge, actualErrorMessage); } @Test