diff --git a/README.md b/README.md index b2b4569..c6cbc0d 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,10 @@ A Java library to use the OpenAI Api in the simplest possible way. - [Chat Completion with Structured Outputs](#chat-completion-with-structured-outputs) - [Chat Completion Conversation Example](#chat-completion-conversation-example) - [Assistant v2 Conversation Example](#assistant-v2-conversation-example) - - [Realtime Conversation Example](#realtime-conversation-example) **UPDATED** + - [Realtime Conversation Example](#realtime-conversation-example) - [Exception Handling](#-exception-handling) - [Support for Additional OpenAI Providers](#-support-for-additional-openai-providers) + - [Mistral API](#mistral-api) **NEW** - [Azure OpenAI](#azure-openai) - [Anyscale](#anyscale) - [Run Examples](#-run-examples) @@ -57,8 +58,8 @@ Full support for most of the OpenAI services: * Image (Generate, Edit, Variation) * Models (List) * Moderation (Check Harmful Text) -* Realtime Beta (Speech-to-Speech Conversation, Multimodality, Function Calling) **UPDATED** -* Session Token (Create Ephemeral Tokens) **NEW** +* Realtime Beta (Speech-to-Speech Conversation, Multimodality, Function Calling) +* Session Token (Create Ephemeral Tokens) * Upload (Upload Large Files in Parts) * Assistants Beta v2 (Assistants, Threads, Messages, Runs, Steps, Vector Stores, Streaming, Function Calling, Vision, Structured Outputs) @@ -966,6 +967,20 @@ This exception handling mechanism allows you to handle API errors and provide fe ## ✴ Support for Additional OpenAI Providers Simple-OpenAI can be used with additional providers that are compatible with the OpenAI API. At this moment, there is support for the following additional providers: +### Mistral API +[Mistral API](https://docs.mistral.ai/getting-started/quickstart/) is supported by Simple-OpenAI. We can use the class `SimpleOpenAIMistral` to start using this provider. +```java +var openai = SimpleOpenAIMistral.builder() + .apiKey(System.getenv("MISTRAL_API_KEY")) + //.baseUrl(customUrl) Optionally you could pass a custom baseUrl + //.httpClient(customHttpClient) Optionally you could pass a custom HttpClient + .build(); +``` +Currently we are supporting the following services: +- `chatCompletionService` (text generation, streaming, function calling, vision) +- `embeddingService` (float format) +- `modelService` (list, detail, delete) + ### Azure OpenAI [Azure OpenIA](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) is supported by Simple-OpenAI. We can use the class `SimpleOpenAIAzure` to start using this provider. ```java @@ -1006,7 +1021,7 @@ Examples for each OpenAI service have been created in the folder [demo](https:// ``` mvn clean install ``` -* Create an environment variable for your OpenAI Api Key: +* Create an environment variable for your OpenAI Api Key (the variable varies according to the OpenAI provider that we want to run): ``` export OPENAI_API_KEY= ``` @@ -1020,33 +1035,7 @@ Examples for each OpenAI service have been created in the folder [demo](https:// ``` Where: - * `````` Is mandatory and must be one of the values: - * Audio - * Batch - * Chat - * Completion - * Embedding - * Exception - * File - * Finetuning - * Image - * Model - * Moderation - * Realtime - * SessionToken - * Upload - * Conversation - * AssistantV2 - * ThreadV2 - * ThreadMessageV2 - * ThreadRunV2 - * ThreadRunStepV2 - * VectorStoreV2 - * VectorStoreFileV2 - * VectorStoreFileBatchV2 - * ConversationV2 - * ChatAnyscale - * ChatAzure + * `````` Is mandatory and must be one of the Java files in the folder demo without the suffix `Demo`, for example: _Audio, Chat, ChatMistral, Realtime, AssistantV2, Conversation, ConversationV2, etc._ * ```[debug]``` Is optional and creates the ```demo.log``` file where you can see log details for each execution. * For example, to run the chat demo with a log file: ```./rundemo.sh Chat debug``` diff --git a/pom.xml b/pom.xml index 08ba681..f2c7905 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.github.sashirestela simple-openai - 3.13.0 + 3.14.0 jar simple-openai @@ -47,19 +47,24 @@ - + UTF-8 + 3.6.3 11 full - - 2.0.16 - 1.6.2 - 1.2.2 + + + 1.6.3 + 1.2.3 1.18.36 2.18.2 4.37.0 - 5.11.4 - 5.14.2 + 2.0.16 + + + [5.11.0,6.0.0) + [5.14.0,6.0.0) + 3.13.0 3.5.0 @@ -73,6 +78,7 @@ 3.2.7 1.7.0 2.43.0 + 3.8.1 @@ -81,6 +87,12 @@ true + + + org.slf4j + slf4j-simple + + @@ -192,6 +204,7 @@ + io.github.sashirestela cleverclient @@ -202,17 +215,6 @@ slimvalidator ${slimvalidator.version} - - org.slf4j - slf4j-api - ${slf4j.version} - - - org.slf4j - slf4j-simple - ${slf4j.version} - true - org.projectlombok lombok @@ -224,6 +226,8 @@ jackson-databind ${jackson.version} + + com.github.victools jsonschema-generator @@ -240,6 +244,21 @@ jsonschema-module-jackson ${json.schema.version} + + + + org.slf4j + slf4j-api + ${slf4j.version} + + + org.slf4j + slf4j-simple + ${slf4j.version} + provided + + + org.junit.jupiter junit-jupiter @@ -257,17 +276,12 @@ mockito-junit-jupiter ${mockito.version} test - - - org.junit.jupiter - junit-jupiter-api - - + io.github.sashirestela cleverclient @@ -284,6 +298,8 @@ com.fasterxml.jackson.core jackson-databind + + com.github.victools jsonschema-generator @@ -292,6 +308,14 @@ com.github.victools jsonschema-module-jackson + + + + org.slf4j + slf4j-api + + + org.junit.jupiter junit-jupiter @@ -304,14 +328,6 @@ org.mockito mockito-junit-jupiter - - org.slf4j - slf4j-api - - - org.slf4j - slf4j-simple - @@ -321,6 +337,19 @@ maven-compiler-plugin ${compiler.version} + + org.apache.maven.plugins + maven-dependency-plugin + ${dependency.version} + + + analyze + + analyze-only + + + + org.apache.maven.plugins maven-enforcer-plugin diff --git a/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java index c0ea140..1e9b38d 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java @@ -3,7 +3,7 @@ import io.github.sashirestela.openai.SimpleOpenAI; import io.github.sashirestela.openai.SimpleOpenAIAnyscale; import io.github.sashirestela.openai.SimpleOpenAIAzure; -import io.github.sashirestela.openai.service.ChatCompletionServices; +import io.github.sashirestela.openai.SimpleOpenAIMistral; import lombok.NonNull; import java.util.ArrayList; @@ -14,8 +14,7 @@ public abstract class AbstractDemo { protected SimpleOpenAI openAI; protected SimpleOpenAIAzure openAIAzure; protected SimpleOpenAIAnyscale openAIAnyscale; - - protected ChatCompletionServices chatProvider; + protected SimpleOpenAIMistral openAIMistral; private static List titleActions = new ArrayList<>(); private static final int TIMES = 80; @@ -31,7 +30,6 @@ protected AbstractDemo(String provider) { .apiKey(System.getenv("OPENAI_API_KEY")) .organizationId(System.getenv("OPENAI_ORGANIZATION_ID")) .build(); - chatProvider = openAI; break; case "azure": openAIAzure = SimpleOpenAIAzure.builder() @@ -39,13 +37,16 @@ protected AbstractDemo(String provider) { .apiVersion(System.getenv("AZURE_OPENAI_API_VERSION")) .baseUrl(System.getenv("AZURE_OPENAI_BASE_URL")) .build(); - chatProvider = openAIAzure; break; case "anyscale": openAIAnyscale = SimpleOpenAIAnyscale.builder() .apiKey(System.getenv("ANYSCALE_API_KEY")) .build(); - chatProvider = openAIAnyscale; + break; + case "mistral": + openAIMistral = SimpleOpenAIMistral.builder() + .apiKey(System.getenv("MISTRAL_API_KEY")) + .build(); break; default: break; diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java index 812fe29..b0144a1 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleDemo.java @@ -2,12 +2,13 @@ public class ChatAnyscaleDemo extends ChatDemo { - public ChatAnyscaleDemo(String provider, String model) { - super(provider, model, null); + public ChatAnyscaleDemo(String model) { + super("anyscale", model, null); + this.chatProvider = this.openAIAnyscale; } public static void main(String[] args) { - var demo = new ChatAnyscaleDemo("anyscale", "mistralai/Mixtral-8x7B-Instruct-v0.1"); + var demo = new ChatAnyscaleDemo("mistralai/Mixtral-8x7B-Instruct-v0.1"); demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java index 4f78418..a4dae4f 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureDemo.java @@ -12,8 +12,9 @@ public class ChatAzureDemo extends ChatDemo { - public ChatAzureDemo(String provider, String model) { - super(provider, model, null); + public ChatAzureDemo() { + super("azure", "N/A", null); + this.chatProvider = this.openAIAzure; } @Override @@ -51,7 +52,7 @@ public void demoCallChatWithVisionLocalImage() { } public static void main(String[] args) { - var demo = new ChatAzureDemo("azure", "N/A"); + var demo = new ChatAzureDemo(); demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java index 8d81f99..4e4e599 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatDemo.java @@ -26,6 +26,7 @@ import io.github.sashirestela.openai.domain.chat.ChatRequest; import io.github.sashirestela.openai.domain.chat.ChatRequest.Audio; import io.github.sashirestela.openai.domain.chat.ChatRequest.Modality; +import io.github.sashirestela.openai.service.ChatCompletionServices; import io.github.sashirestela.openai.support.Base64Util; import io.github.sashirestela.openai.support.Base64Util.MediaType; @@ -34,14 +35,21 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.TimeUnit; public class ChatDemo extends AbstractDemo { protected ChatRequest chatRequest; protected String model; protected String modelAudio; + protected long sleepSeconds; + protected ChatCompletionServices chatProvider; - public ChatDemo(String provider, String model, String modelAudio) { + public ChatDemo(String model, String modelAudio) { + this("standard", model, modelAudio); + } + + protected ChatDemo(String provider, String model, String modelAudio) { super(provider); this.model = model; this.modelAudio = modelAudio; @@ -52,12 +60,14 @@ public ChatDemo(String provider, String model, String modelAudio) { .temperature(0.0) .maxTokens(300) .build(); + this.sleepSeconds = 0; // No sleep by default + this.chatProvider = this.openAI; } public void demoCallChatStreaming() { var futureChat = chatProvider.chatCompletions().createStream(chatRequest); var chatResponse = futureChat.join(); - chatResponse.forEach(ChatDemo::processResponseChunk); + chatResponse.forEach(this::processResponseChunk); } public void demoCallChatBlocking() { @@ -108,13 +118,14 @@ public void demoCallChatWithFunctions() { .messages(messages) .tools(functionExecutor.getToolFunctions()) .build(); + sleep(); futureChat = chatProvider.chatCompletions().create(chatRequest); chatResponse = futureChat.join(); System.out.println(chatResponse.firstContent()); } public void demoCallChatWithVisionExternalImage() { - var chatRequest = ChatRequest.builder() + chatRequest = ChatRequest.builder() .model(model) .messages(List.of( UserMessage.of(List.of( @@ -123,15 +134,15 @@ public void demoCallChatWithVisionExternalImage() { ContentPartImageUrl.of(ImageUrl.of( "https://upload.wikimedia.org/wikipedia/commons/e/eb/Machu_Picchu%2C_Peru.jpg")))))) .temperature(0.0) - .maxCompletionTokens(500) + .maxTokens(500) .build(); var chatResponse = chatProvider.chatCompletions().createStream(chatRequest).join(); - chatResponse.forEach(ChatDemo::processResponseChunk); + chatResponse.forEach(this::processResponseChunk); System.out.println(); } public void demoCallChatWithVisionLocalImage() { - var chatRequest = ChatRequest.builder() + chatRequest = ChatRequest.builder() .model(model) .messages(List.of( UserMessage.of(List.of( @@ -140,15 +151,15 @@ public void demoCallChatWithVisionLocalImage() { ContentPartImageUrl.of(ImageUrl.of( Base64Util.encode("src/demo/resources/machupicchu.jpg", MediaType.IMAGE))))))) .temperature(0.0) - .maxCompletionTokens(500) + .maxTokens(500) .build(); var chatResponse = chatProvider.chatCompletions().createStream(chatRequest).join(); - chatResponse.forEach(ChatDemo::processResponseChunk); + chatResponse.forEach(this::processResponseChunk); System.out.println(); } public void demoCallChatWithStructuredOutputs() { - var chatRequest = ChatRequest.builder() + chatRequest = ChatRequest.builder() .model(model) .message(SystemMessage .of("You are a helpful math tutor. Guide the user through the solution step by step.")) @@ -159,7 +170,7 @@ public void demoCallChatWithStructuredOutputs() { .build())) .build(); var chatResponse = chatProvider.chatCompletions().createStream(chatRequest).join(); - chatResponse.forEach(ChatDemo::processResponseChunk); + chatResponse.forEach(this::processResponseChunk); System.out.println(); } @@ -171,7 +182,7 @@ public void demoCallChatWithStructuredOutputs2() { } catch (IOException e) { e.printStackTrace(); } - var chatRequest = ChatRequest.builder() + chatRequest = ChatRequest.builder() .model(model) .message(SystemMessage .of("You are a helpful math tutor. Guide the user through the solution step by step.")) @@ -182,7 +193,7 @@ public void demoCallChatWithStructuredOutputs2() { .build())) .build(); var chatResponse = chatProvider.chatCompletions().createStream(chatRequest).join(); - chatResponse.forEach(ChatDemo::processResponseChunk); + chatResponse.forEach(this::processResponseChunk); System.out.println(); } @@ -219,7 +230,7 @@ public void demoCallChatWithAudioInputOutput() { System.out.println("Answer 2: " + audio.getTranscript()); } - private static void processResponseChunk(Chat responseChunk) { + private void processResponseChunk(Chat responseChunk) { var choices = responseChunk.getChoices(); if (!choices.isEmpty()) { var delta = choices.get(0).getMessage(); @@ -234,6 +245,16 @@ private static void processResponseChunk(Chat responseChunk) { } } + private void sleep() { + if (this.sleepSeconds > 0) { + try { + TimeUnit.SECONDS.sleep(this.sleepSeconds); + } catch (InterruptedException e) { + java.lang.Thread.currentThread().interrupt(); + } + } + } + public static class Weather implements Functional { @JsonPropertyDescription("City and state, for example: León, Guanajuato") @@ -292,7 +313,7 @@ public static class Step { } public static void main(String[] args) { - var demo = new ChatDemo("standard", "gpt-4o-mini", "gpt-4o-audio-preview"); + var demo = new ChatDemo("gpt-4o-mini", "gpt-4o-audio-preview"); demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatMistralDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatMistralDemo.java new file mode 100644 index 0000000..d2e7eae --- /dev/null +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatMistralDemo.java @@ -0,0 +1,23 @@ +package io.github.sashirestela.openai.demo; + +public class ChatMistralDemo extends ChatDemo { + + public ChatMistralDemo(String model) { + super("mistral", model, null); + this.sleepSeconds = 1; //Free tier limit: 1 request per 1 second + this.chatProvider = this.openAIMistral; + } + + public static void main(String[] args) { + var demo = new ChatMistralDemo("pixtral-12b-2409"); + + demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); + demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); + demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions); + demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage); + demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage); + + demo.run(); + } + +} diff --git a/src/demo/java/io/github/sashirestela/openai/demo/EmbeddingDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/EmbeddingDemo.java index c57fed2..12f0db6 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/EmbeddingDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/EmbeddingDemo.java @@ -3,19 +3,33 @@ import io.github.sashirestela.openai.domain.embedding.EmbeddingBase64; import io.github.sashirestela.openai.domain.embedding.EmbeddingFloat; import io.github.sashirestela.openai.domain.embedding.EmbeddingRequest; +import io.github.sashirestela.openai.service.EmbeddingServices; import java.util.Arrays; public class EmbeddingDemo extends AbstractDemo { + protected String model; + protected EmbeddingServices embeddingProvider; + + public EmbeddingDemo(String model) { + this("standard", model); + } + + protected EmbeddingDemo(String provider, String model) { + super(provider); + this.model = model; + this.embeddingProvider = this.openAI; + } + public void demoCallEmbeddingFloat() { var embeddingRequest = EmbeddingRequest.builder() - .model("text-embedding-ada-002") + .model(this.model) .input(Arrays.asList( "shiny sun", "blue sky")) .build(); - var futureEmbedding = openAI.embeddings().create(embeddingRequest); + var futureEmbedding = embeddingProvider.embeddings().create(embeddingRequest); var embeddingResponse = futureEmbedding.join(); embeddingResponse.getData() .stream() @@ -25,12 +39,12 @@ public void demoCallEmbeddingFloat() { public void demoCallEmbeddingBase64() { var embeddingRequest = EmbeddingRequest.builder() - .model("text-embedding-ada-002") + .model(this.model) .input(Arrays.asList( "shiny sun", "blue sky")) .build(); - var futureEmbedding = openAI.embeddings().createBase64(embeddingRequest); + var futureEmbedding = embeddingProvider.embeddings().createBase64(embeddingRequest); var embeddingResponse = futureEmbedding.join(); embeddingResponse.getData() .stream() @@ -39,7 +53,7 @@ public void demoCallEmbeddingBase64() { } public static void main(String[] args) { - var demo = new EmbeddingDemo(); + var demo = new EmbeddingDemo("text-embedding-3-small"); demo.addTitleAction("Call Embedding Float Format", demo::demoCallEmbeddingFloat); demo.addTitleAction("Call Embedding Base64 Format", demo::demoCallEmbeddingBase64); diff --git a/src/demo/java/io/github/sashirestela/openai/demo/EmbeddingMistralDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/EmbeddingMistralDemo.java new file mode 100644 index 0000000..e39d265 --- /dev/null +++ b/src/demo/java/io/github/sashirestela/openai/demo/EmbeddingMistralDemo.java @@ -0,0 +1,18 @@ +package io.github.sashirestela.openai.demo; + +public class EmbeddingMistralDemo extends EmbeddingDemo { + + public EmbeddingMistralDemo(String model) { + super("mistral", model); + this.embeddingProvider = this.openAIMistral; + } + + public static void main(String[] args) { + var demo = new EmbeddingMistralDemo("mistral-embed"); + + demo.addTitleAction("Call Embedding Float Format", demo::demoCallEmbeddingFloat); + + demo.run(); + } + +} diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ModelDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ModelDemo.java index 43f8d5b..ed3a1d4 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/ModelDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/ModelDemo.java @@ -1,11 +1,23 @@ package io.github.sashirestela.openai.demo; +import io.github.sashirestela.openai.service.ModelServices; + public class ModelDemo extends AbstractDemo { - private String modelId; + protected String modelId; + protected ModelServices modelProvider; + + public ModelDemo() { + this("standard"); + } + + protected ModelDemo(String provider) { + super(provider); + this.modelProvider = this.openAI; + } public void demoGetModels() { - var futureModels = openAI.models().getList(); + var futureModels = modelProvider.models().getList(); var models = futureModels.join(); models.forEach(System.out::println); @@ -13,14 +25,14 @@ public void demoGetModels() { } public void demoGetModel() { - var futureModel = openAI.models().getOne(modelId); + var futureModel = modelProvider.models().getOne(modelId); var model = futureModel.join(); System.out.println(model); } @SuppressWarnings("unused") public void demoDeleteModel() { - var futureModel = openAI.models().delete(modelId); + var futureModel = modelProvider.models().delete(modelId); try { var deleted = futureModel.join(); } catch (Exception e) { diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ModelMistralDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ModelMistralDemo.java new file mode 100644 index 0000000..4efac7d --- /dev/null +++ b/src/demo/java/io/github/sashirestela/openai/demo/ModelMistralDemo.java @@ -0,0 +1,20 @@ +package io.github.sashirestela.openai.demo; + +public class ModelMistralDemo extends ModelDemo { + + public ModelMistralDemo() { + super("mistral"); + this.modelProvider = this.openAIMistral; + } + + public static void main(String[] args) { + var demo = new ModelMistralDemo(); + + demo.addTitleAction("List of All Models", demo::demoGetModels); + demo.addTitleAction("First Model in List", demo::demoGetModel); + demo.addTitleAction("Trying to Delete a Model", demo::demoDeleteModel); + + demo.run(); + } + +} diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java index 5bfba5b..fb8a1ae 100644 --- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java @@ -5,6 +5,7 @@ import io.github.sashirestela.openai.base.OpenAIConfigurator; import io.github.sashirestela.openai.base.OpenAIProvider; import io.github.sashirestela.openai.base.RealtimeConfig; +import io.github.sashirestela.openai.exception.SimpleOpenAIException; import io.github.sashirestela.openai.service.AssistantServices; import io.github.sashirestela.openai.service.AudioServices; import io.github.sashirestela.openai.service.BatchServices; @@ -204,6 +205,9 @@ private Map makeHeaders() { headers.put(Constant.OPENAI_ORG_HEADER, organizationId); } if (projectId != null) { + if (organizationId == null) { + throw new SimpleOpenAIException("OrganizationId should be provided if ProjectId is provided."); + } headers.put(Constant.OPENAI_PRJ_HEADER, projectId); } return headers; diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIMistral.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIMistral.java new file mode 100644 index 0000000..96537b9 --- /dev/null +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIMistral.java @@ -0,0 +1,113 @@ +package io.github.sashirestela.openai; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.github.sashirestela.cleverclient.http.HttpRequestData; +import io.github.sashirestela.cleverclient.support.ContentType; +import io.github.sashirestela.openai.OpenAI.ChatCompletions; +import io.github.sashirestela.openai.OpenAI.Embeddings; +import io.github.sashirestela.openai.OpenAI.Models; +import io.github.sashirestela.openai.base.ClientConfig; +import io.github.sashirestela.openai.base.OpenAIConfigurator; +import io.github.sashirestela.openai.base.OpenAIProvider; +import io.github.sashirestela.openai.service.ChatCompletionServices; +import io.github.sashirestela.openai.service.EmbeddingServices; +import io.github.sashirestela.openai.service.ModelServices; +import io.github.sashirestela.openai.support.Constant; +import lombok.Builder; +import lombok.NonNull; +import lombok.experimental.SuperBuilder; + +import java.net.http.HttpClient; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.function.UnaryOperator; + +/** + * The Mistral OpenAI provider which implements a subset of the standard services. + */ +public class SimpleOpenAIMistral extends OpenAIProvider implements + ChatCompletionServices, + EmbeddingServices, + ModelServices { + + /** + * Constructor used to generate a builder. + * + * @param apiKey Identifier to be used for authentication. Mandatory. + * @param baseUrl Host's url. Optional. + * @param httpClient A {@link java.net.http.HttpClient HttpClient} object. One is created by + * default if not provided. Optional. + * @param objectMapper Provides Json conversions either to and from objects. Optional. + */ + @Builder + public SimpleOpenAIMistral(@NonNull String apiKey, String baseUrl, HttpClient httpClient, + ObjectMapper objectMapper) { + super(MistralConfigurator.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .httpClient(httpClient) + .objectMapper(objectMapper) + .build()); + } + + @Override + public ChatCompletions chatCompletions() { + return getOrCreateService(OpenAI.ChatCompletions.class); + } + + @Override + public Embeddings embeddings() { + return getOrCreateService(OpenAI.Embeddings.class); + } + + @Override + public Models models() { + return getOrCreateService(OpenAI.Models.class); + } + + @SuperBuilder + static class MistralConfigurator extends OpenAIConfigurator { + + @Override + public ClientConfig buildConfig() { + return ClientConfig.builder() + .baseUrl(Optional.ofNullable(baseUrl).orElse(Constant.MISTRAL_BASE_URL)) + .headers(Map.of(Constant.AUTHORIZATION_HEADER, Constant.BEARER_AUTHORIZATION + apiKey)) + .httpClient(httpClient) + .requestInterceptor(makeRequestInterceptor()) + .objectMapper(objectMapper) + .build(); + } + + private UnaryOperator makeRequestInterceptor() { + return request -> { + var contentType = request.getContentType(); + if (contentType != null && contentType.equals(ContentType.APPLICATION_JSON)) { + var body = makeNewBody(request); + request.setBody(body); + } + return request; + }; + } + + private String makeNewBody(HttpRequestData request) { + Map mapRegexReplace = new HashMap<>(); + mapRegexReplace.put(",\\s*\"stream_options\"\\s*:\\s*\\{[^{}]*\\}", ""); // Remove "stream_options" + mapRegexReplace.put(",?\\s*\"additionalProperties\"\\s*:\\s*false\\s*", ""); // Remove "additionalProperties" + mapRegexReplace.put(",?\\s*\"strict\"\\s*:\\s*true\\s*", ""); // Remove "strict" + mapRegexReplace.put(",\\s*,", ","); // Replace double commas by one comma + mapRegexReplace.put(",\\s*}", "}"); // Replace trailing comma by closing brace + mapRegexReplace.put("\"index\"\\s*:\\s*null\\s*,\\s*", ""); // Remove "index: null" + mapRegexReplace.put(",\\s*\"refusal\"\\s*:\\s*null", ""); // Remove "refusal: null" + mapRegexReplace.put(",\\s*\"audio\"\\s*:\\s*null", ""); // Remove "audio: null" + var body = (String) request.getBody(); + for (var entry : mapRegexReplace.entrySet()) { + body = body.replaceAll(entry.getKey(), entry.getValue()); + } + return body; + } + + } + +} diff --git a/src/main/java/io/github/sashirestela/openai/support/Constant.java b/src/main/java/io/github/sashirestela/openai/support/Constant.java index 139bd6f..df8f7d8 100644 --- a/src/main/java/io/github/sashirestela/openai/support/Constant.java +++ b/src/main/java/io/github/sashirestela/openai/support/Constant.java @@ -22,4 +22,6 @@ private Constant() { public static final String AZURE_APIKEY_HEADER = "api-key"; public static final String AZURE_API_VERSION = "api-version"; + public static final String MISTRAL_BASE_URL = "https://api.mistral.ai"; + } diff --git a/src/test/java/io/github/sashirestela/openai/SimpleUncheckedExceptionTest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIExceptionTest.java similarity index 96% rename from src/test/java/io/github/sashirestela/openai/SimpleUncheckedExceptionTest.java rename to src/test/java/io/github/sashirestela/openai/SimpleOpenAIExceptionTest.java index 0d63c07..5d598f2 100644 --- a/src/test/java/io/github/sashirestela/openai/SimpleUncheckedExceptionTest.java +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIExceptionTest.java @@ -6,7 +6,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; -class SimpleUncheckedExceptionTest { +class SimpleOpenAIExceptionTest { @Test void shouldSetMessageWhenItIsPassedAsTheOnlyOneArgument() { diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIMistralTest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIMistralTest.java new file mode 100644 index 0000000..68f37f2 --- /dev/null +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIMistralTest.java @@ -0,0 +1,60 @@ +package io.github.sashirestela.openai; + +import io.github.sashirestela.cleverclient.http.HttpRequestData; +import io.github.sashirestela.cleverclient.support.ContentType; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class SimpleOpenAIMistralTest { + + @Test + void shouldCreateEndpoints() { + var openAI = SimpleOpenAIMistral.builder() + .apiKey("apiKey") + .baseUrl("baseUrl") + .build(); + assertNotNull(openAI.chatCompletions()); + assertNotNull(openAI.embeddings()); + assertNotNull(openAI.models()); + } + + @Test + void shouldInterceptRequest() { + var actualBody = readJsonFile("src/test/resources/mistral_body_actual.json"); + var actualRequest = HttpRequestData.builder() + .url("url") + .contentType(ContentType.APPLICATION_JSON) + .body(actualBody) + .build(); + var expectedBody = readJsonFile("src/test/resources/mistral_body_expected.json"); + var expectedRequest = HttpRequestData.builder() + .url("url") + .contentType(ContentType.APPLICATION_JSON) + .body(expectedBody) + .build(); + var clientConfig = SimpleOpenAIMistral.MistralConfigurator.builder() + .apiKey("apiKey") + .baseUrl("url") + .build() + .buildConfig(); + actualRequest = clientConfig.getRequestInterceptor().apply(actualRequest); + assertEquals(expectedRequest.getBody(), actualRequest.getBody()); + } + + private String readJsonFile(String filePath) { + String json; + try { + json = Files.readAllLines(Paths.get(filePath)).get(0); + } catch (IOException e) { + json = null; + } + return json; + } + +} diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java index 23f52af..3e5cc85 100644 --- a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java @@ -5,6 +5,7 @@ import io.github.sashirestela.openai.base.RealtimeConfig; import io.github.sashirestela.openai.domain.chat.ChatMessage.UserMessage; import io.github.sashirestela.openai.domain.chat.ChatRequest; +import io.github.sashirestela.openai.exception.SimpleOpenAIException; import io.github.sashirestela.openai.support.Constant; import io.github.sashirestela.slimvalidator.exception.ConstraintViolationException; import org.junit.jupiter.api.Test; @@ -75,6 +76,16 @@ void shouldCreateConfigWithDefaultValuesWhenRequiredParametersArePassed() { assertNull(clientConfig.getRequestInterceptor()); } + @Test + void shouldThrowExceptionWhenProjectIdIsProvidedAndOrganizationIdIsNot() { + var configurator = SimpleOpenAI.StandardConfigurator.builder() + .apiKey("apiKey") + .projectId("projectId") + .build(); + var exception = assertThrows(SimpleOpenAIException.class, () -> configurator.buildConfig()); + assertEquals("OrganizationId should be provided if ProjectId is provided.", exception.getMessage()); + } + @Test @SuppressWarnings("unchecked") void shouldNotDuplicateContentTypeHeaderWhenCallingSimpleOpenAI() { diff --git a/src/test/resources/mistral_body_actual.json b/src/test/resources/mistral_body_actual.json new file mode 100644 index 0000000..7d61656 --- /dev/null +++ b/src/test/resources/mistral_body_actual.json @@ -0,0 +1 @@ +{"messages":[{"role":"user","content":"What is the product of 123 and 456?"},{"role":"assistant","content":null,"tool_calls":[{"index":null,"id":"call_83EGwRvtPRUOQUsr3jGyHxAr","type":"function","function":{"name":"product","arguments":"{\"multiplicand\":123,\"multiplier\":456}"}}],"refusal":null,"audio":null},{"role":"tool","content":"56088.0","tool_call_id":"call_83EGwRvtPRUOQUsr3jGyHxAr"}],"model":"gpt-4o-mini","stream":true,"stream_options":{"include_usage":true},"temperature":0.0,"tools":[{"type":"function","function":{"name":"product","description":"Get the product of two numbers","parameters":{"type":"object","properties":{"multiplicand":{"type":"number","description":"The multiplicand part of a product"},"multiplier":{"type":"number","description":"The multiplier part of a product"}},"required":["multiplicand","multiplier"],"additionalProperties":false},"strict":true}}],"tool_choice":"auto"} \ No newline at end of file diff --git a/src/test/resources/mistral_body_expected.json b/src/test/resources/mistral_body_expected.json new file mode 100644 index 0000000..35114f2 --- /dev/null +++ b/src/test/resources/mistral_body_expected.json @@ -0,0 +1 @@ +{"messages":[{"role":"user","content":"What is the product of 123 and 456?"},{"role":"assistant","content":null,"tool_calls":[{"id":"call_83EGwRvtPRUOQUsr3jGyHxAr","type":"function","function":{"name":"product","arguments":"{\"multiplicand\":123,\"multiplier\":456}"}}]},{"role":"tool","content":"56088.0","tool_call_id":"call_83EGwRvtPRUOQUsr3jGyHxAr"}],"model":"gpt-4o-mini","stream":true,"temperature":0.0,"tools":[{"type":"function","function":{"name":"product","description":"Get the product of two numbers","parameters":{"type":"object","properties":{"multiplicand":{"type":"number","description":"The multiplicand part of a product"},"multiplier":{"type":"number","description":"The multiplier part of a product"}},"required":["multiplicand","multiplier"]}}}],"tool_choice":"auto"} \ No newline at end of file